Commit df22ba01 authored by ltqin's avatar ltqin
Browse files

start to use atomic add

parent 162359b6
...@@ -50,23 +50,23 @@ using DeviceConvWrWInstance = ck::tensor_operation::device:: ...@@ -50,23 +50,23 @@ using DeviceConvWrWInstance = ck::tensor_operation::device::
32, // NPerXdl 32, // NPerXdl
2, // MXdlPerWave 2, // MXdlPerWave
2, // NXdlPerWave 2, // NXdlPerWave
S<4, 16, 4>, // ABlockTransferThreadClusterLengths_K0_M_K1 S<1, 4, 16, 4>, // ABlockTransferThreadClusterLengths_K0_M_K1
S<2, 0, 1>, // ABlockTransferThreadClusterArrangeOrder S<0, 3, 1, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder S<0, 2, 1, 3>, // ABlockTransferSrcAccessOrder
1, // ABlockTransferSrcVectorDim 2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector 8, // ABlockTransferSrcScalarPerVector
2, // ABlockTransferDstScalarPerVector_K1 2, // ABlockTransferDstScalarPerVector_K1
true, // ABlockLdsAddExtraM true, // ABlockLdsAddExtraM
S<4, 16, 4>, // BBlockTransferThreadClusterLengths_K0_N_K1 S<1, 4, 16, 4>, // BBlockTransferThreadClusterLengths_K0_N_K1
S<2, 0, 1>, // BBlockTransferThreadClusterArrangeOrder S<0, 3, 1, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder S<0, 2, 1, 3>, // BBlockTransferSrcAccessOrder
1, // BBlockTransferSrcVectorDim 2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector 8, // BBlockTransferSrcScalarPerVector
2, // BBlockTransferDstScalarPerVector_K1 2, // BBlockTransferDstScalarPerVector_K1
true, // BBlockLdsAddExtraN true, // BBlockLdsAddExtraN
1, // CShuffleMXdlPerWavePerShuffle 1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle 1, // CShuffleNXdlPerWavePerShuffle
S<1, 1, 32, 1, 1, 8>, // CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl S<1, 16, 1, 4>, //
8>; // CBlockTransferScalarPerVector_NWaveNPerXdl 8>; // CBlockTransferScalarPerVector_NWaveNPerXdl
// clang-format on // clang-format on
...@@ -82,7 +82,7 @@ int main(int argc, char* argv[]) ...@@ -82,7 +82,7 @@ int main(int argc, char* argv[])
// Conv shape // Conv shape
ck::index_t N = 128; ck::index_t N = 128;
ck::index_t K = 256; ck::index_t K = 256;
ck::index_t C = 192; ck::index_t C = 128;
ck::index_t Y = 3; ck::index_t Y = 3;
ck::index_t X = 3; ck::index_t X = 3;
ck::index_t Hi = 71; ck::index_t Hi = 71;
...@@ -95,6 +95,7 @@ int main(int argc, char* argv[]) ...@@ -95,6 +95,7 @@ int main(int argc, char* argv[])
ck::index_t in_left_pad_w = 1; ck::index_t in_left_pad_w = 1;
ck::index_t in_right_pad_h = 1; ck::index_t in_right_pad_h = 1;
ck::index_t in_right_pad_w = 1; ck::index_t in_right_pad_w = 1;
ck::index_t split_k = 1;
if(argc == 4) if(argc == 4)
{ {
...@@ -102,7 +103,7 @@ int main(int argc, char* argv[]) ...@@ -102,7 +103,7 @@ int main(int argc, char* argv[])
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
nrepeat = std::stoi(argv[3]); nrepeat = std::stoi(argv[3]);
} }
else if(argc == 19) else if(argc == 20)
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
...@@ -123,6 +124,7 @@ int main(int argc, char* argv[]) ...@@ -123,6 +124,7 @@ int main(int argc, char* argv[])
in_left_pad_w = std::stoi(argv[16]); in_left_pad_w = std::stoi(argv[16]);
in_right_pad_h = std::stoi(argv[17]); in_right_pad_h = std::stoi(argv[17]);
in_right_pad_w = std::stoi(argv[18]); in_right_pad_w = std::stoi(argv[18]);
split_k = std::stoi(argv[19]);
} }
else else
{ {
...@@ -185,12 +187,13 @@ int main(int argc, char* argv[]) ...@@ -185,12 +187,13 @@ int main(int argc, char* argv[])
case 0: break; case 0: break;
case 1: case 1:
in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5}); in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5});
out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5}); out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5});
break; break;
default: default:
in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3<InDataType>{0.0, 1.0}); in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_1<InDataType>{1});
out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5}); out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_1<OutDataType>{1});
} }
wei_k_c_y_x_device_result.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{0});
DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace()); DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace());
DeviceMem wei_device_buf(sizeof(WeiDataType) * DeviceMem wei_device_buf(sizeof(WeiDataType) *
...@@ -199,6 +202,9 @@ int main(int argc, char* argv[]) ...@@ -199,6 +202,9 @@ int main(int argc, char* argv[])
in_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); in_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
out_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); out_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
wei_device_buf.ToDevice(wei_k_c_y_x_device_result.mData.data());
LogRangeAsType<float>(std::cout << "wei_device(before): ", wei_k_c_y_x_device_result.mData, ",")
<< std::endl;
// do GEMM // do GEMM
auto conv = DeviceConvWrWInstance{}; auto conv = DeviceConvWrWInstance{};
...@@ -218,7 +224,8 @@ int main(int argc, char* argv[]) ...@@ -218,7 +224,8 @@ int main(int argc, char* argv[])
input_right_pads, input_right_pads,
InElementOp{}, InElementOp{},
WeiElementOp{}, WeiElementOp{},
OutElementOp{}); OutElementOp{},
split_k);
if(!conv.IsSupportedArgument(argument)) if(!conv.IsSupportedArgument(argument))
{ {
...@@ -262,6 +269,16 @@ int main(int argc, char* argv[]) ...@@ -262,6 +269,16 @@ int main(int argc, char* argv[])
wei_device_buf.FromDevice(wei_k_c_y_x_device_result.mData.data()); wei_device_buf.FromDevice(wei_k_c_y_x_device_result.mData.data());
if(1)
{
LogRangeAsType<float>(std::cout << "out: ", out_n_k_ho_wo.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "in : ", in_n_c_hi_wi.mData, ",") << std::endl;
LogRangeAsType<float>(
std::cout << "wei_device(after): ", wei_k_c_y_x_device_result.mData, ",")
<< std::endl;
LogRangeAsType<float>(std::cout << "wei_host : ", wei_k_c_y_x_host_result.mData, ",")
<< std::endl;
}
check_error(wei_k_c_y_x_host_result, wei_k_c_y_x_device_result); check_error(wei_k_c_y_x_host_result, wei_k_c_y_x_device_result);
} }
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment