Commit 3991a1c1 authored by qinletao's avatar qinletao
Browse files

reorganize example

parent 7e8e54de
......@@ -46,7 +46,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl
//##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
#if 1
#if 0
< F64, F64, F64, F64, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 32, 4, 1, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 7, 1>;
using ADataType = double;
using BDataType = double;
......@@ -59,10 +59,15 @@ using BDataType = float;
using CDataType = float;
using AccDataType = float;
#endif
// clang-format on
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CDataType,
AccDataType,
AElementOp,
BElementOp,
CElementOp>;
template <typename DataType>
std::ostream& show_2d_matrix(std::ostream& os, Tensor<DataType>& matrix)
......@@ -88,13 +93,13 @@ int main(int argc, char* argv[])
int nrepeat = 5;
// GEMM shape
ck::index_t M = 3840;
ck::index_t N = 4096;
ck::index_t K = 4096;
ck::index_t M = 32;
ck::index_t N = 32;
ck::index_t K = 4;
ck::index_t StrideA = 4096;
ck::index_t StrideB = 4096;
ck::index_t StrideC = 4096;
ck::index_t StrideA = 4;
ck::index_t StrideB = 4;
ck::index_t StrideC = 32;
if(argc == 4)
{
......@@ -144,6 +149,7 @@ int main(int argc, char* argv[])
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
std::cout << "data type: " << typeid(ADataType{}).name() << std::endl;
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
......@@ -160,10 +166,8 @@ int main(int argc, char* argv[])
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
break;
default:
// a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
// b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
}
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
......@@ -225,7 +229,7 @@ int main(int argc, char* argv[])
ref_invoker.Run(ref_argument);
#if 1
#if 0
{
show_2d_matrix(std::cout << "a : ", a_m_k) << std::endl;
show_2d_matrix(std::cout << "b: ", b_k_n) << std::endl;
......
......@@ -49,14 +49,14 @@ using DeviceConvNDFwdInstance = ck::tensor_operation::device::
ConvFwdDefault, // ConvForwardSpecialization
NumDimSpatial, // NumDimSpatial
256, // BlockSize
64, // MPerBlock
64, // NPerBlock
128, // MPerBlock
128, // NPerBlock
4, // K0PerBlock
2, // K1
16, // MPerXDL
16, // NPerXDL
2, // MXdlPerWave
2, // NXdlPerWave
4, // MXdlPerWave
4, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
......@@ -241,10 +241,8 @@ int main(int argc, char* argv[])
weights.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5});
break;
default:
// input.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5});
weights.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
input.GenerateTensorValue(GeneratorTensor_1<InDataType>{1});
// weights.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{1});
weights.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{1});
}
DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpace());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment