Commit ed424975 authored by Anthony Chang's avatar Anthony Chang
Browse files

tidy up example

parent 5f94555b
...@@ -76,7 +76,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmGemm_X ...@@ -76,7 +76,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmGemm_X
128, // MPerBlock 128, // MPerBlock
128, // NPerBlock 128, // NPerBlock
32, // KPerBlock 32, // KPerBlock
128, // Gemm1NPerBlock 128, // Gemm1NPerBlock
32, // Gemm1KPerBlock 32, // Gemm1KPerBlock
8, // AK1 8, // AK1
8, // BK1 8, // BK1
...@@ -129,29 +129,19 @@ int main(int argc, char* argv[]) ...@@ -129,29 +129,19 @@ int main(int argc, char* argv[])
bool time_kernel = false; bool time_kernel = false;
// GEMM shape // GEMM shape
// ck::index_t M = 1024; ck::index_t M = 1024;
// ck::index_t N = 1024; ck::index_t N = 1024;
// ck::index_t K = 64; ck::index_t K = 64;
// ck::index_t O = 64; ck::index_t O = 128;
// ck::index_t BatchCount = 4; ck::index_t BatchCount = 4;
// ck::index_t StrideA = 1024; ck::index_t StrideA = -1;
// ck::index_t StrideB0 = 1024; ck::index_t StrideB0 = -1;
// ck::index_t StrideB1 = 1024; ck::index_t StrideB1 = -1;
// ck::index_t StrideC = 1024; ck::index_t StrideC = -1;
ck::index_t BatchStrideA = -1;
ck::index_t M = 256;
ck::index_t N = 128;
ck::index_t K = 32;
ck::index_t O = 128;
ck::index_t BatchCount = 4;
ck::index_t StrideA = 32;
ck::index_t StrideB0 = 32;
ck::index_t StrideB1 = 128;
ck::index_t StrideC = 128;
ck::index_t BatchStrideA = -1;
ck::index_t BatchStrideB0 = -1; ck::index_t BatchStrideB0 = -1;
ck::index_t BatchStrideB1 = -1; ck::index_t BatchStrideB1 = -1;
ck::index_t BatchStrideC = -1; ck::index_t BatchStrideC = -1;
if(argc == 1) if(argc == 1)
{ {
...@@ -163,6 +153,19 @@ int main(int argc, char* argv[]) ...@@ -163,6 +153,19 @@ int main(int argc, char* argv[])
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
} }
else if(argc == 9)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
O = std::stoi(argv[7]);
BatchCount = std::stoi(argv[8]);
}
else if(argc == 17) else if(argc == 17)
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
...@@ -191,7 +194,8 @@ int main(int argc, char* argv[]) ...@@ -191,7 +194,8 @@ int main(int argc, char* argv[])
printf("arg1: verification (0=no, 1=yes)\n"); printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n"); printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"); printf("arg4 to 17: M, N, K, O, Batch, StrideA, StrideB0, StrideB1, StrideC, BatchStrideA, "
"BatchStrideB0, BatchStrideB1, BatchStrideC\n");
exit(0); exit(0);
} }
...@@ -205,10 +209,10 @@ int main(int argc, char* argv[]) ...@@ -205,10 +209,10 @@ int main(int argc, char* argv[])
StrideB1 = (StrideB1 < 0) ? DefaultStrideB1 : StrideB1; StrideB1 = (StrideB1 < 0) ? DefaultStrideB1 : StrideB1;
StrideC = (StrideC < 0) ? DefaultStrideC : StrideC; StrideC = (StrideC < 0) ? DefaultStrideC : StrideC;
const int DefaultBatchStrideA = (ck::is_same_v<ALayout, Row> ? K : M) * StrideA; const int DefaultBatchStrideA = (ck::is_same_v<ALayout, Col> ? K : M) * StrideA;
const int DefaultBatchStrideB0 = (ck::is_same_v<B0Layout, Row> ? N : K) * StrideB0; const int DefaultBatchStrideB0 = (ck::is_same_v<B0Layout, Col> ? N : K) * StrideB0;
const int DefaultBatchStrideB1 = (ck::is_same_v<B1Layout, Row> ? O : N) * StrideB1; const int DefaultBatchStrideB1 = (ck::is_same_v<B1Layout, Col> ? O : N) * StrideB1;
const int DefaultBatchStrideC = (ck::is_same_v<CLayout, Row> ? O : M) * StrideC; const int DefaultBatchStrideC = (ck::is_same_v<CLayout, Col> ? O : M) * StrideC;
BatchStrideA = BatchStrideA < 0 ? DefaultBatchStrideA : BatchStrideA; BatchStrideA = BatchStrideA < 0 ? DefaultBatchStrideA : BatchStrideA;
BatchStrideB0 = BatchStrideB0 < 0 ? DefaultBatchStrideB0 : BatchStrideB0; BatchStrideB0 = BatchStrideB0 < 0 ? DefaultBatchStrideB0 : BatchStrideB0;
...@@ -234,49 +238,49 @@ int main(int argc, char* argv[]) ...@@ -234,49 +238,49 @@ int main(int argc, char* argv[])
}; };
// C_m_o = A_m_k * B0_k_n * B1_n_o // C_m_o = A_m_k * B0_k_n * B1_n_o
Tensor<ADataType> a_m_k( Tensor<ADataType> a_g_m_k(
f_host_tensor_descriptor(BatchCount, M, K, StrideA, BatchStrideA, ALayout{})); f_host_tensor_descriptor(BatchCount, M, K, StrideA, BatchStrideA, ALayout{}));
Tensor<B0DataType> b0_k_n( Tensor<B0DataType> b0_g_k_n(
f_host_tensor_descriptor(BatchCount, K, N, StrideB0, BatchStrideB0, B0Layout{})); f_host_tensor_descriptor(BatchCount, K, N, StrideB0, BatchStrideB0, B0Layout{}));
Tensor<B1DataType> b1_n_o( Tensor<B1DataType> b1_g_n_o(
f_host_tensor_descriptor(BatchCount, N, O, StrideB1, BatchStrideB1, B1Layout{})); f_host_tensor_descriptor(BatchCount, N, O, StrideB1, BatchStrideB1, B1Layout{}));
Tensor<CDataType> c_m_o_host_result( Tensor<CDataType> c_g_m_o_host_result(
f_host_tensor_descriptor(BatchCount, M, O, StrideC, BatchStrideC, CLayout{})); f_host_tensor_descriptor(BatchCount, M, O, StrideC, BatchStrideC, CLayout{}));
Tensor<CDataType> c_m_o_device_result( Tensor<CDataType> c_g_m_o_device_result(
f_host_tensor_descriptor(BatchCount, M, O, StrideC, BatchStrideC, CLayout{})); f_host_tensor_descriptor(BatchCount, M, O, StrideC, BatchStrideC, CLayout{}));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl; std::cout << "b0_g_k_n: " << b0_g_k_n.mDesc << std::endl;
std::cout << "b1_n_o: " << b1_n_o.mDesc << std::endl; std::cout << "b1_g_n_o: " << b1_g_n_o.mDesc << std::endl;
std::cout << "c_m_o: " << c_m_o_host_result.mDesc << std::endl; std::cout << "c_g_m_o: " << c_g_m_o_host_result.mDesc << std::endl;
switch(init_method) switch(init_method)
{ {
case 0: break; case 0: break;
case 1: case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5}); a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5}); b0_g_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5});
b1_n_o.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-5, 5}); b1_g_n_o.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-5, 5});
break; break;
case 2: case 2:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}); a_g_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b0_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0}); b0_g_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
b1_n_o.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5}); b1_g_n_o.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
break; break;
default: default:
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1}); a_g_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
b0_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{}); b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
b1_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{}); b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
} }
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); DeviceMem a_g_m_k_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpace());
DeviceMem b0_k_n_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpace()); DeviceMem b0_g_k_n_device_buf(sizeof(B0DataType) * b0_g_k_n.mDesc.GetElementSpace());
DeviceMem b1_n_o_device_buf(sizeof(B1DataType) * b1_n_o.mDesc.GetElementSpace()); DeviceMem b1_g_n_o_device_buf(sizeof(B1DataType) * b1_g_n_o.mDesc.GetElementSpace());
DeviceMem c_m_o_device_buf(sizeof(CDataType) * c_m_o_device_result.mDesc.GetElementSpace()); DeviceMem c_g_m_o_device_buf(sizeof(CDataType) * c_g_m_o_device_result.mDesc.GetElementSpace());
a_m_k_device_buf.ToDevice(a_m_k.mData.data()); a_g_m_k_device_buf.ToDevice(a_g_m_k.mData.data());
b0_k_n_device_buf.ToDevice(b0_k_n.mData.data()); b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data());
b1_n_o_device_buf.ToDevice(b1_n_o.mData.data()); b1_g_n_o_device_buf.ToDevice(b1_g_n_o.mData.data());
auto a_element_op = AElementOp{}; auto a_element_op = AElementOp{};
auto b0_element_op = B0ElementOp{}; auto b0_element_op = B0ElementOp{};
...@@ -286,10 +290,10 @@ int main(int argc, char* argv[]) ...@@ -286,10 +290,10 @@ int main(int argc, char* argv[])
// do GEMM // do GEMM
auto gemm = DeviceGemmInstance{}; auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker(); auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()), auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_g_m_k_device_buf.GetDeviceBuffer()),
static_cast<B0DataType*>(b0_k_n_device_buf.GetDeviceBuffer()), static_cast<B0DataType*>(b0_g_k_n_device_buf.GetDeviceBuffer()),
static_cast<B1DataType*>(b1_n_o_device_buf.GetDeviceBuffer()), static_cast<B1DataType*>(b1_g_n_o_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_m_o_device_buf.GetDeviceBuffer()), static_cast<CDataType*>(c_g_m_o_device_buf.GetDeviceBuffer()),
M, M,
N, N,
K, K,
...@@ -329,28 +333,28 @@ int main(int argc, char* argv[]) ...@@ -329,28 +333,28 @@ int main(int argc, char* argv[])
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl; << gemm.GetTypeString() << std::endl;
c_m_o_device_buf.FromDevice(c_m_o_device_result.mData.data()); c_g_m_o_device_buf.FromDevice(c_g_m_o_device_result.mData.data());
if(do_verification) if(do_verification)
{ {
// Output of Gemm0 is input A of Gemm1 // Output of Gemm0 is input A of Gemm1
Tensor<ADataType> a1_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{})); Tensor<ADataType> a1_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
auto ref_gemm0 = ReferenceGemm0Instance{}; auto ref_gemm0 = ReferenceGemm0Instance{};
auto ref_gemm0_invoker = ref_gemm0.MakeInvoker(); auto ref_gemm0_invoker = ref_gemm0.MakeInvoker();
auto ref_gemm0_argument = ref_gemm0.MakeArgument( auto ref_gemm0_argument = ref_gemm0.MakeArgument(
a_m_k, b0_k_n, a1_m_n, a_element_op, b0_element_op, PassThrough{}); a_g_m_k, b0_g_k_n, a1_g_m_n, a_element_op, b0_element_op, PassThrough{});
ref_gemm0_invoker.Run(ref_gemm0_argument); ref_gemm0_invoker.Run(ref_gemm0_argument);
auto ref_gemm1 = ReferenceGemm1Instance{}; auto ref_gemm1 = ReferenceGemm1Instance{};
auto ref_gemm1_invoker = ref_gemm1.MakeInvoker(); auto ref_gemm1_invoker = ref_gemm1.MakeInvoker();
auto ref_gemm1_argument = ref_gemm1.MakeArgument( auto ref_gemm1_argument = ref_gemm1.MakeArgument(
a1_m_n, b1_n_o, c_m_o_host_result, PassThrough{}, b1_element_op, c_element_op); a1_g_m_n, b1_g_n_o, c_g_m_o_host_result, PassThrough{}, b1_element_op, c_element_op);
ref_gemm1_invoker.Run(ref_gemm1_argument); ref_gemm1_invoker.Run(ref_gemm1_argument);
return ck::utils::check_err(c_m_o_device_result.mData, c_m_o_host_result.mData) ? 0 : 1; return ck::utils::check_err(c_g_m_o_device_result.mData, c_g_m_o_host_result.mData) ? 0 : 1;
} }
return 0; return 0;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment