Commit b571256f authored by ltqin's avatar ltqin
Browse files

fix bug after merge develop

parent f9c478e2
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
#include "device_tensor.hpp" #include "device_tensor.hpp"
#include "device_gemm_xdl_skip_b_lds.hpp" #include "device_gemm_xdl_skip_b_lds.hpp"
#include "device_gemm_xdl.hpp" #include "device_gemm_xdl.hpp"
#include "device_gemm_xdl_c_shuffle.hpp"
#include "device_gemm_xdl_cshuffle.hpp" #include "device_gemm_xdl_cshuffle.hpp"
#include "element_wise_operation.hpp" #include "element_wise_operation.hpp"
#include "reference_gemm.hpp" #include "reference_gemm.hpp"
...@@ -82,7 +81,7 @@ using AccDataType = float; ...@@ -82,7 +81,7 @@ using AccDataType = float;
// clang-format on // clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host:: using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AElementOp, BElementOp, CElementOp>; ReferenceGemm<ADataType, BDataType, CDataType, float, AElementOp, BElementOp, CElementOp>;
template <typename DataType> template <typename DataType>
std::ostream& show_2d_matrix(std::ostream& os, Tensor<DataType>& matrix) std::ostream& show_2d_matrix(std::ostream& os, Tensor<DataType>& matrix)
...@@ -104,7 +103,7 @@ int main(int argc, char* argv[]) ...@@ -104,7 +103,7 @@ int main(int argc, char* argv[])
{ {
bool do_verification = 0; bool do_verification = 0;
int init_method = 0; int init_method = 0;
int nrepeat = 5; bool time_kernel = false;
// GEMM shape // GEMM shape
#if 1 #if 1
...@@ -129,13 +128,13 @@ int main(int argc, char* argv[]) ...@@ -129,13 +128,13 @@ int main(int argc, char* argv[])
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
} }
else if(argc == 10) else if(argc == 10)
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]); M = std::stoi(argv[4]);
N = std::stoi(argv[5]); N = std::stoi(argv[5]);
...@@ -149,7 +148,7 @@ int main(int argc, char* argv[]) ...@@ -149,7 +148,7 @@ int main(int argc, char* argv[])
{ {
printf("arg1: verification (0=no, 1=yes)\n"); printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: run kernel # of times (>1)\n"); printf("arg3: time kernel (0=n0, 1=yes)\n");
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"); printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
exit(0); exit(0);
} }
...@@ -228,7 +227,7 @@ int main(int argc, char* argv[]) ...@@ -228,7 +227,7 @@ int main(int argc, char* argv[])
"not support this GEMM problem"); "not support this GEMM problem");
} }
float ave_time = invoker.Run(argument, nrepeat); float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * M * N * K; std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype = std::size_t num_btype =
......
...@@ -283,7 +283,7 @@ struct DeviceGemmXdlSkipBLds ...@@ -283,7 +283,7 @@ struct DeviceGemmXdlSkipBLds
{ {
using Argument = DeviceGemmXdlSkipBLds::Argument; using Argument = DeviceGemmXdlSkipBLds::Argument;
float Run(const Argument& arg, int nrepeat = 1) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
{ {
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
...@@ -332,8 +332,8 @@ struct DeviceGemmXdlSkipBLds ...@@ -332,8 +332,8 @@ struct DeviceGemmXdlSkipBLds
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>, remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
true>; true>;
ave_time = launch_and_time_kernel(kernel, ave_time = launch_and_time_kernel(stream_config,
nrepeat, kernel,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
...@@ -364,8 +364,8 @@ struct DeviceGemmXdlSkipBLds ...@@ -364,8 +364,8 @@ struct DeviceGemmXdlSkipBLds
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>, remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
false>; false>;
ave_time = launch_and_time_kernel(kernel, ave_time = launch_and_time_kernel(stream_config,
nrepeat, kernel,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
...@@ -385,9 +385,10 @@ struct DeviceGemmXdlSkipBLds ...@@ -385,9 +385,10 @@ struct DeviceGemmXdlSkipBLds
} }
// polymorphic // polymorphic
float Run(const BaseArgument* p_arg, int nrepeat = 1) override float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{ {
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat); return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
} }
}; };
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
#include "tensor_descriptor_helper.hpp" #include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_xdlops.hpp" #include "blockwise_gemm_xdlops.hpp"
#include "blockwise_gemm_xdlops_skip_b_lds.hpp" #include "blockwise_gemm_xdlops_skip_b_lds.hpp"
#include "blockwise_tensor_slice_transfer_v4r1.hpp" #include "thread_group_tensor_slice_transfer_v4r1.hpp"
#include "threadwise_tensor_slice_transfer.hpp" #include "threadwise_tensor_slice_transfer.hpp"
#include "gridwise_gemm_pipeline_v1.hpp" #include "gridwise_gemm_pipeline_v1.hpp"
...@@ -124,6 +124,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1 ...@@ -124,6 +124,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, K1>{}; static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, K1>{};
static constexpr index_t K0PerThread = K0PerBlock / xdlops_gemm.K0PerXdlops; static constexpr index_t K0PerThread = K0PerBlock / xdlops_gemm.K0PerXdlops;
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
__host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
{ {
constexpr auto max_lds_align = K1; constexpr auto max_lds_align = K1;
...@@ -397,33 +399,34 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1 ...@@ -397,33 +399,34 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
// A matrix blockwise copy // A matrix blockwise copy
auto a_blockwise_copy = auto a_blockwise_copy =
BlockwiseTensorSliceTransfer_v4r1<BlockSize, ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
AElementwiseOperation, AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
Sequence<K0PerBlock * MultiK0, MPerBlock, K1>, Sequence<K0PerBlock * MultiK0, MPerBlock, K1>,
ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
FloatAB, FloatAB,
FloatAB, FloatAB,
decltype(a_grid_desc_k0_m_k1), decltype(a_grid_desc_k0_m_k1),
decltype(a_block_desc_k0_m_k1), decltype(a_block_desc_k0_m_k1),
ABlockTransferSrcAccessOrder, ABlockTransferSrcAccessOrder,
Sequence<1, 0, 2>, Sequence<1, 0, 2>,
ABlockTransferSrcVectorDim, ABlockTransferSrcVectorDim,
2, 2,
ABlockTransferSrcScalarPerVector, ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1, ABlockTransferDstScalarPerVector_K1,
1, 1,
1, 1,
AThreadTransferSrcResetCoordinateAfterRun, AThreadTransferSrcResetCoordinateAfterRun,
true, true,
1>(a_grid_desc_k0_m_k1, 1>(
make_multi_index(0, m_block_data_idx_on_grid, 0), a_grid_desc_k0_m_k1,
a_element_op, make_multi_index(0, m_block_data_idx_on_grid, 0),
a_block_desc_k0_m_k1, a_element_op,
make_multi_index(0, 0, 0), a_block_desc_k0_m_k1,
ck::tensor_operation::element_wise::PassThrough{}); make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
ignore = b_element_op; ignore = b_element_op;
// B matrix threadwise copy // B matrix threadwise copy
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment