Commit 4d2172a9 authored by ltqin's avatar ltqin
Browse files

raw not split version

parent 669df2d3
......@@ -31,9 +31,6 @@ using InElementOp = ck::tensor_operation::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
using OutElementOp = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto ConvFwdDefault =
ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Default;
// clang-format off
using DeviceConvWrWInstance = ck::tensor_operation::device::
DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<
......@@ -44,24 +41,23 @@ using DeviceConvWrWInstance = ck::tensor_operation::device::
InElementOp, // InElementwiseOperation
WeiElementOp, // WeiElementwiseOperation
OutElementOp, // OutElementwiseOperation
ConvFwdDefault, // ConvForwardSpecialization
256, // BlockSize
128, // MPerBlock
256, // NPerBlock
128, // NPerBlock
4, // K0PerBlock
8, // K1
32, // MPerXdl
32, // NPerXdl
2, // MXdlPerWave
4, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
2, // NXdlPerWave
S<4, 16, 4>, // ABlockTransferThreadClusterLengths_K0_M_K1
S<0, 2, 1>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
1, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_K1
true, // ABlockLdsAddExtraM
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1
S<4, 16, 4>, // BBlockTransferThreadClusterLengths_K0_N_K1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
......@@ -176,9 +172,9 @@ int main(int argc, char* argv[])
Tensor<InDataType> in_n_c_hi_wi(f_host_tensor_descriptor(N, C, Hi, Wi, InLayout{}));
Tensor<WeiDataType> wei_k_c_y_x_host_result(f_host_tensor_descriptor(K, C, Y, X, WeiLayout{}));
Tensor<WeiDataType> wei_k_c_y_x_device_result(f_host_tensor_descriptor(K, C, Y, X, WeiLayout{}));
Tensor<OutDataType> out_n_k_ho_wo(
f_host_tensor_descriptor(N, K, Ho, Wo, OutLayout{}));
Tensor<WeiDataType> wei_k_c_y_x_device_result(
f_host_tensor_descriptor(K, C, Y, X, WeiLayout{}));
Tensor<OutDataType> out_n_k_ho_wo(f_host_tensor_descriptor(N, K, Ho, Wo, OutLayout{}));
std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl;
std::cout << "wei_k_c_y_x: " << wei_k_c_y_x_host_result.mDesc << std::endl;
......@@ -197,9 +193,9 @@ int main(int argc, char* argv[])
}
DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace());
DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_k_c_y_x_device_result.mDesc.GetElementSpace());
DeviceMem out_device_buf(sizeof(OutDataType) *
out_n_k_ho_wo.mDesc.GetElementSpace());
DeviceMem wei_device_buf(sizeof(WeiDataType) *
wei_k_c_y_x_device_result.mDesc.GetElementSpace());
DeviceMem out_device_buf(sizeof(OutDataType) * out_n_k_ho_wo.mDesc.GetElementSpace());
in_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
out_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment