Commit df22ba01 authored by ltqin's avatar ltqin
Browse files

start to use atomic add

parent 162359b6
......@@ -50,23 +50,23 @@ using DeviceConvWrWInstance = ck::tensor_operation::device::
32, // NPerXdl
2, // MXdlPerWave
2, // NXdlPerWave
S<4, 16, 4>, // ABlockTransferThreadClusterLengths_K0_M_K1
S<2, 0, 1>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
1, // ABlockTransferSrcVectorDim
S<1, 4, 16, 4>, // ABlockTransferThreadClusterLengths_K0_M_K1
S<0, 3, 1, 2>, // ABlockTransferThreadClusterArrangeOrder
S<0, 2, 1, 3>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector
2, // ABlockTransferDstScalarPerVector_K1
true, // ABlockLdsAddExtraM
S<4, 16, 4>, // BBlockTransferThreadClusterLengths_K0_N_K1
S<2, 0, 1>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
1, // BBlockTransferSrcVectorDim
S<1, 4, 16, 4>, // BBlockTransferThreadClusterLengths_K0_N_K1
S<0, 3, 1, 2>, // BBlockTransferThreadClusterArrangeOrder
S<0, 2, 1, 3>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector
2, // BBlockTransferDstScalarPerVector_K1
true, // BBlockLdsAddExtraN
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
S<1, 1, 32, 1, 1, 8>, // CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
S<1, 16, 1, 4>, //
8>; // CBlockTransferScalarPerVector_NWaveNPerXdl
// clang-format on
......@@ -82,7 +82,7 @@ int main(int argc, char* argv[])
// Conv shape
ck::index_t N = 128;
ck::index_t K = 256;
ck::index_t C = 192;
ck::index_t C = 128;
ck::index_t Y = 3;
ck::index_t X = 3;
ck::index_t Hi = 71;
......@@ -95,6 +95,7 @@ int main(int argc, char* argv[])
ck::index_t in_left_pad_w = 1;
ck::index_t in_right_pad_h = 1;
ck::index_t in_right_pad_w = 1;
ck::index_t split_k = 1;
if(argc == 4)
{
......@@ -102,7 +103,7 @@ int main(int argc, char* argv[])
init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]);
}
else if(argc == 19)
else if(argc == 20)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
......@@ -123,6 +124,7 @@ int main(int argc, char* argv[])
in_left_pad_w = std::stoi(argv[16]);
in_right_pad_h = std::stoi(argv[17]);
in_right_pad_w = std::stoi(argv[18]);
split_k = std::stoi(argv[19]);
}
else
{
......@@ -185,12 +187,13 @@ int main(int argc, char* argv[])
case 0: break;
case 1:
in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5});
out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5});
break;
default:
in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3<InDataType>{0.0, 1.0});
out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5});
in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_1<InDataType>{1});
out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_1<OutDataType>{1});
}
wei_k_c_y_x_device_result.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{0});
DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace());
DeviceMem wei_device_buf(sizeof(WeiDataType) *
......@@ -199,6 +202,9 @@ int main(int argc, char* argv[])
in_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
out_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
wei_device_buf.ToDevice(wei_k_c_y_x_device_result.mData.data());
LogRangeAsType<float>(std::cout << "wei_device(before): ", wei_k_c_y_x_device_result.mData, ",")
<< std::endl;
// do GEMM
auto conv = DeviceConvWrWInstance{};
......@@ -218,7 +224,8 @@ int main(int argc, char* argv[])
input_right_pads,
InElementOp{},
WeiElementOp{},
OutElementOp{});
OutElementOp{},
split_k);
if(!conv.IsSupportedArgument(argument))
{
......@@ -262,6 +269,16 @@ int main(int argc, char* argv[])
wei_device_buf.FromDevice(wei_k_c_y_x_device_result.mData.data());
if(1)
{
LogRangeAsType<float>(std::cout << "out: ", out_n_k_ho_wo.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "in : ", in_n_c_hi_wi.mData, ",") << std::endl;
LogRangeAsType<float>(
std::cout << "wei_device(after): ", wei_k_c_y_x_device_result.mData, ",")
<< std::endl;
LogRangeAsType<float>(std::cout << "wei_host : ", wei_k_c_y_x_host_result.mData, ",")
<< std::endl;
}
check_error(wei_k_c_y_x_host_result, wei_k_c_y_x_device_result);
}
}
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment