Commit b571256f authored by ltqin's avatar ltqin
Browse files

fix bug after merge develop

parent f9c478e2
......@@ -13,7 +13,6 @@
#include "device_tensor.hpp"
#include "device_gemm_xdl_skip_b_lds.hpp"
#include "device_gemm_xdl.hpp"
#include "device_gemm_xdl_c_shuffle.hpp"
#include "device_gemm_xdl_cshuffle.hpp"
#include "element_wise_operation.hpp"
#include "reference_gemm.hpp"
......@@ -82,7 +81,7 @@ using AccDataType = float;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AElementOp, BElementOp, CElementOp>;
ReferenceGemm<ADataType, BDataType, CDataType, float, AElementOp, BElementOp, CElementOp>;
template <typename DataType>
std::ostream& show_2d_matrix(std::ostream& os, Tensor<DataType>& matrix)
......@@ -104,7 +103,7 @@ int main(int argc, char* argv[])
{
bool do_verification = 0;
int init_method = 0;
int nrepeat = 5;
bool time_kernel = false;
// GEMM shape
#if 1
......@@ -129,13 +128,13 @@ int main(int argc, char* argv[])
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 10)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
......@@ -149,7 +148,7 @@ int main(int argc, char* argv[])
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: run kernel # of times (>1)\n");
printf("arg3: time kernel (0=n0, 1=yes)\n");
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
exit(0);
}
......@@ -228,7 +227,7 @@ int main(int argc, char* argv[])
"not support this GEMM problem");
}
float ave_time = invoker.Run(argument, nrepeat);
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
......
......@@ -283,7 +283,7 @@ struct DeviceGemmXdlSkipBLds
{
using Argument = DeviceGemmXdlSkipBLds::Argument;
float Run(const Argument& arg, int nrepeat = 1)
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
{
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
......@@ -332,8 +332,8 @@ struct DeviceGemmXdlSkipBLds
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
true>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
......@@ -364,8 +364,8 @@ struct DeviceGemmXdlSkipBLds
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
false>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
......@@ -385,9 +385,10 @@ struct DeviceGemmXdlSkipBLds
}
// polymorphic
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
......
......@@ -7,7 +7,7 @@
#include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "blockwise_gemm_xdlops_skip_b_lds.hpp"
#include "blockwise_tensor_slice_transfer_v4r1.hpp"
#include "thread_group_tensor_slice_transfer_v4r1.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "gridwise_gemm_pipeline_v1.hpp"
......@@ -124,6 +124,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, K1>{};
static constexpr index_t K0PerThread = K0PerBlock / xdlops_gemm.K0PerXdlops;
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
__host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
{
constexpr auto max_lds_align = K1;
......@@ -397,7 +399,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
// A matrix blockwise copy
auto a_blockwise_copy =
BlockwiseTensorSliceTransfer_v4r1<BlockSize,
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
......@@ -418,7 +420,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
1,
AThreadTransferSrcResetCoordinateAfterRun,
true,
1>(a_grid_desc_k0_m_k1,
1>(
a_grid_desc_k0_m_k1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_k0_m_k1,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment