Commit 7b1ce567 authored by Jing Zhang
Browse files

clean code

parent 97a5b74a
......@@ -80,12 +80,10 @@ struct DeviceGemmXdl
const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
std::cout << "PadM = " << PadM << " M = " << M + PadM << std::endl;
const auto a_grid_desc_k0_m_k1 =
transform_tensor_descriptor(a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_pad_transform(M, I0, PadM)),
make_right_pad_transform(M, PadM)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
......@@ -111,12 +109,10 @@ struct DeviceGemmXdl
const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
std::cout << "PadN = " << PadN << " N = " << N + PadN << std::endl;
const auto b_grid_desc_k0_n_k1 =
transform_tensor_descriptor(b_grid_desc_k_n,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_pad_transform(N, I0, PadN)),
make_right_pad_transform(N, PadN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
......@@ -141,7 +137,7 @@ struct DeviceGemmXdl
const auto c_grid_desc_m_n_ = transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_pad_transform(M, I0, PadM), make_pad_transform(N, I0, PadN)),
make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
......@@ -314,9 +310,10 @@ struct DeviceGemmXdl
float Run(const Argument& arg, int nrepeat = 1)
{
{
std::cout << "MPerBlock = " << MPerBlock << " NPerBlock = " << NPerBlock
<< " MXdlPerWave = " << MXdlPerWave << " NXdlPerWave = " << NXdlPerWave
<< std::endl;
std::cout << "BlockGemmShape: {" << MPerBlock << ", " << NPerBlock << ", "
<< K0PerBlock << "}, WaveGemmShape: {" << MXdlPerWave * MPerXDL << ", "
<< NXdlPerWave * NPerXDL << "} XDLGemmShape: {" << MPerXDL << ", "
<< NPerXDL << "}" << std::endl;
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
......
......@@ -10,6 +10,7 @@ include_directories(BEFORE
${PROJECT_SOURCE_DIR}/composable_kernel/include/problem_transform
${PROJECT_SOURCE_DIR}/composable_kernel/include/driver
${PROJECT_SOURCE_DIR}/external/rocm/include
${PROJECT_SOURCE_DIR}/device_operation/include
)
set(CONV_FWD_DRIVER_OFFLINE_SOURCE src/conv_fwd_driver_offline.cpp)
......
......@@ -5,15 +5,7 @@
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdlops_v2r3.hpp"
// Identity element-wise functor: hands back exactly the value it receives.
// The argument is taken by value and returned by value.
struct OpPassThrough
{
    // Returns `value` unchanged. Qualified __host__ __device__ so the call
    // operator is usable from both host-side and device-side code.
    template <typename Value>
    __host__ __device__ constexpr auto operator()(Value value) const -> Value
    {
        return value;
    }
};
#include "element_wise_operation.hpp"
template <ck::index_t BlockSize,
typename FloatAB,
......@@ -79,7 +71,7 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
using ElementwiseOperation = OpPassThrough;
using ElementwiseOperation = ck::tensor_operation::element_wise::PassThrough;
using GridwiseGemm =
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
......@@ -166,7 +158,7 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
float ave_time = 0;
auto element_op_ = OpPassThrough{};
auto element_op_ = ElementwiseOperation{};
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
if(has_main_k0_block_loop)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment