"vscode:/vscode.git/clone" did not exist on "37b82b7e5484e510c65c01efb9a5421498e3db96"
Commit 3991a1c1 authored by qinletao

reorganize example

parent 7e8e54de
@@ -46,7 +46,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl
//##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-#if 1
+#if 0
< F64, F64, F64, F64, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 32, 4, 1, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 7, 1>;
using ADataType = double;
using BDataType = double;
@@ -59,10 +59,15 @@ using BDataType = float;
using CDataType = float;
using AccDataType = float;
#endif
// clang-format on
-using ReferenceGemmInstance = ck::tensor_operation::host::
-    ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
+using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
+                                                                        BDataType,
+                                                                        CDataType,
+                                                                        AccDataType,
+                                                                        AElementOp,
+                                                                        BElementOp,
+                                                                        CElementOp>;
template <typename DataType>
std::ostream& show_2d_matrix(std::ostream& os, Tensor<DataType>& matrix)
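The ReferenceGemm instance reformatted above is the host-side golden model the xdlops kernel is later validated against. Conceptually it is a triple loop that accumulates in AccDataType and applies the three PassThrough elementwise operators. A minimal stand-alone sketch of that computation for the Row/Col/Row layouts used in this example (a hypothetical free function, not the CK class, which is driven through its argument/invoker interface):

    // Minimal host GEMM sketch: C[m][n] = sum_k A[m][k] * B[k][n], accumulated in AccType.
    // Illustration only; the real ReferenceGemm is not this function.
    #include <cstddef>
    #include <vector>

    template <typename AType, typename BType, typename CType, typename AccType>
    void naive_gemm(const std::vector<AType>& a, // M x K, row-major
                    const std::vector<BType>& b, // K x N, column-major
                    std::vector<CType>& c,       // M x N, row-major
                    std::size_t M, std::size_t N, std::size_t K)
    {
        for(std::size_t m = 0; m < M; ++m)
            for(std::size_t n = 0; n < N; ++n)
            {
                AccType acc = 0;
                for(std::size_t k = 0; k < K; ++k)
                    acc += static_cast<AccType>(a[m * K + k]) * static_cast<AccType>(b[n * K + k]);
                c[m * N + n] = static_cast<CType>(acc); // PassThrough leaves the value unchanged
            }
    }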
@@ -88,13 +93,13 @@ int main(int argc, char* argv[])
int nrepeat = 5;
// GEMM shape
-ck::index_t M = 3840;
-ck::index_t N = 4096;
-ck::index_t K = 4096;
-ck::index_t StrideA = 4096;
-ck::index_t StrideB = 4096;
-ck::index_t StrideC = 4096;
+ck::index_t M = 32;
+ck::index_t N = 32;
+ck::index_t K = 4;
+ck::index_t StrideA = 4;
+ck::index_t StrideB = 4;
+ck::index_t StrideC = 32;
if(argc == 4)
{
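The smaller problem makes the printed matrices readable, and the new strides follow from the Row/Col/Row layouts chosen in the instance above: a row-major M x K matrix has leading dimension K, a column-major K x N matrix also has leading dimension K, and a row-major M x N matrix has leading dimension N, hence 4, 4 and 32 for the 32 x 32 x 4 problem. A small illustrative helper (my own naming, not part of the example):

    // Hypothetical helper: default stride (leading dimension) of a rows x cols matrix.
    enum class Layout { RowMajor, ColMajor };

    constexpr long default_stride(long rows, long cols, Layout layout)
    {
        return layout == Layout::RowMajor ? cols : rows;
    }

    // For M = 32, N = 32, K = 4:
    //   A (32 x 4,  row-major)    -> default_stride(32, 4,  Layout::RowMajor) == 4  == StrideA
    //   B (4  x 32, column-major) -> default_stride(4,  32, Layout::ColMajor) == 4  == StrideB
    //   C (32 x 32, row-major)    -> default_stride(32, 32, Layout::RowMajor) == 32 == StrideC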
@@ -144,6 +149,7 @@ int main(int argc, char* argv[])
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
+std::cout << "data type: " << typeid(ADataType{}).name() << std::endl;
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
@@ -160,10 +166,8 @@ int main(int argc, char* argv[])
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
break;
default:
-// a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
-b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
-// b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
+b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
}
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
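With B switched to GeneratorTensor_1{1}, both inputs are now all ones, so with PassThrough elementwise operators every element of C should equal K (here 4). That makes the matrix dump trivial to eyeball and turns indexing mistakes into obvious patterns. A tiny check along those lines (hypothetical helper, not in the example):

    // With A and B all ones, C[m][n] = sum_k 1 * 1 = K for every m, n.
    #include <cassert>
    #include <vector>

    void check_all_ones_gemm(const std::vector<double>& c, int M, int N, int K)
    {
        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
                assert(c[m * N + n] == static_cast<double>(K)); // expect 4 for the 32 x 32 x 4 case
    }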
@@ -225,7 +229,7 @@ int main(int argc, char* argv[])
ref_invoker.Run(ref_argument);
-#if 1
+#if 0
{
show_2d_matrix(std::cout << "a : ", a_m_k) << std::endl;
show_2d_matrix(std::cout << "b: ", b_k_n) << std::endl;
...
@@ -49,14 +49,14 @@ using DeviceConvNDFwdInstance = ck::tensor_operation::device::
ConvFwdDefault, // ConvForwardSpecialization
NumDimSpatial, // NumDimSpatial
256, // BlockSize
-64, // MPerBlock
-64, // NPerBlock
+128, // MPerBlock
+128, // NPerBlock
4, // K0PerBlock
2, // K1
16, // MPerXDL
16, // NPerXDL
-2, // MXdlPerWave
-2, // NXdlPerWave
+4, // MXdlPerWave
+4, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
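The larger 128 x 128 tile stays consistent with the 256-thread block if one assumes a 64-lane wavefront and the usual decomposition MPerBlock = MPerXDL * MXdlPerWave * MWaves (and likewise for N): 128 = 16 * 4 * 2, so 2 x 2 = 4 waves match 256 / 64 threads. A compile-time sketch of that arithmetic (my assumptions, not CK code):

    // Hedged tile-decomposition check; the 64-lane wave and the MPerBlock relation are assumptions.
    constexpr int BlockSize   = 256;
    constexpr int WaveSize    = 64;  // gfx9 wavefront
    constexpr int MPerBlock   = 128;
    constexpr int NPerBlock   = 128;
    constexpr int MPerXDL     = 16;
    constexpr int NPerXDL     = 16;
    constexpr int MXdlPerWave = 4;
    constexpr int NXdlPerWave = 4;

    constexpr int MWaves = MPerBlock / (MPerXDL * MXdlPerWave); // 2
    constexpr int NWaves = NPerBlock / (NPerXDL * NXdlPerWave); // 2

    static_assert(MWaves * NWaves == BlockSize / WaveSize,
                  "the 2 x 2 waves of a 256-thread block cover the 128 x 128 tile");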
@@ -241,10 +241,8 @@ int main(int argc, char* argv[])
weights.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5});
break;
default:
-// input.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5});
-weights.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
input.GenerateTensorValue(GeneratorTensor_1<InDataType>{1});
-// weights.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{1});
+weights.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{1});
}
DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpace());
...