"include/vscode:/vscode.git/clone" did not exist on "5b57ab96a8208eec1969a3dcadb555a6246ddb95"
Commit 7b1ce567 authored by Jing Zhang's avatar Jing Zhang
Browse files

clean code

parent 97a5b74a
...@@ -80,12 +80,10 @@ struct DeviceGemmXdl ...@@ -80,12 +80,10 @@ struct DeviceGemmXdl
const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
std::cout << "PadM = " << PadM << " M = " << M + PadM << std::endl;
const auto a_grid_desc_k0_m_k1 = const auto a_grid_desc_k0_m_k1 =
transform_tensor_descriptor(a_grid_desc_m_k, transform_tensor_descriptor(a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_pad_transform(M, I0, PadM)), make_right_pad_transform(M, PadM)),
make_tuple(Sequence<1>{}, Sequence<0>{}), make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
...@@ -111,12 +109,10 @@ struct DeviceGemmXdl ...@@ -111,12 +109,10 @@ struct DeviceGemmXdl
const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
std::cout << "PadN = " << PadN << " N = " << N + PadN << std::endl;
const auto b_grid_desc_k0_n_k1 = const auto b_grid_desc_k0_n_k1 =
transform_tensor_descriptor(b_grid_desc_k_n, transform_tensor_descriptor(b_grid_desc_k_n,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_pad_transform(N, I0, PadN)), make_right_pad_transform(N, PadN)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
...@@ -141,7 +137,7 @@ struct DeviceGemmXdl ...@@ -141,7 +137,7 @@ struct DeviceGemmXdl
const auto c_grid_desc_m_n_ = transform_tensor_descriptor( const auto c_grid_desc_m_n_ = transform_tensor_descriptor(
c_grid_desc_m_n, c_grid_desc_m_n,
make_tuple(make_pad_transform(M, I0, PadM), make_pad_transform(N, I0, PadN)), make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
...@@ -314,9 +310,10 @@ struct DeviceGemmXdl ...@@ -314,9 +310,10 @@ struct DeviceGemmXdl
float Run(const Argument& arg, int nrepeat = 1) float Run(const Argument& arg, int nrepeat = 1)
{ {
{ {
std::cout << "MPerBlock = " << MPerBlock << " NPerBlock = " << NPerBlock std::cout << "BlockGemmShape: {" << MPerBlock << ", " << NPerBlock << ", "
<< " MXdlPerWave = " << MXdlPerWave << " NXdlPerWave = " << NXdlPerWave << K0PerBlock << "}, WaveGemmShape: {" << MXdlPerWave * MPerXDL << ", "
<< std::endl; << NXdlPerWave * NPerXDL << "} XDLGemmShape: {" << MPerXDL << ", "
<< NPerXDL << "}" << std::endl;
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
......
...@@ -10,6 +10,7 @@ include_directories(BEFORE ...@@ -10,6 +10,7 @@ include_directories(BEFORE
${PROJECT_SOURCE_DIR}/composable_kernel/include/problem_transform ${PROJECT_SOURCE_DIR}/composable_kernel/include/problem_transform
${PROJECT_SOURCE_DIR}/composable_kernel/include/driver ${PROJECT_SOURCE_DIR}/composable_kernel/include/driver
${PROJECT_SOURCE_DIR}/external/rocm/include ${PROJECT_SOURCE_DIR}/external/rocm/include
${PROJECT_SOURCE_DIR}/device_operation/include
) )
set(CONV_FWD_DRIVER_OFFLINE_SOURCE src/conv_fwd_driver_offline.cpp) set(CONV_FWD_DRIVER_OFFLINE_SOURCE src/conv_fwd_driver_offline.cpp)
......
...@@ -5,15 +5,7 @@ ...@@ -5,15 +5,7 @@
#include "tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp" #include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdlops_v2r3.hpp" #include "gridwise_gemm_xdlops_v2r3.hpp"
#include "element_wise_operation.hpp"
struct OpPassThrough
{
template <typename T>
__host__ __device__ constexpr T operator()(T v) const
{
return v;
}
};
template <ck::index_t BlockSize, template <ck::index_t BlockSize,
typename FloatAB, typename FloatAB,
...@@ -79,7 +71,7 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, ...@@ -79,7 +71,7 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{}; constexpr auto I2 = Number<2>{};
using ElementwiseOperation = OpPassThrough; using ElementwiseOperation = ck::tensor_operation::element_wise::PassThrough;
using GridwiseGemm = using GridwiseGemm =
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize, GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
...@@ -166,7 +158,7 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, ...@@ -166,7 +158,7 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
float ave_time = 0; float ave_time = 0;
auto element_op_ = OpPassThrough{}; auto element_op_ = ElementwiseOperation{};
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE #if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
if(has_main_k0_block_loop) if(has_main_k0_block_loop)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment