"...composable_kernel_rocm.git" did not exist on "c203bf67117ca06b1dbd45b1f88e49c9b8a41db9"
Commit ed424975 authored by Anthony Chang

tidy up example

parent 5f94555b
@@ -76,7 +76,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmGemm_X
     128, // MPerBlock
     128, // NPerBlock
     32,  // KPerBlock
-    128, // Gemm1NPerBlock
+    128, // Gemm1NPerBlock
     32,  // Gemm1KPerBlock
     8,   // AK1
     8,   // BK1
@@ -129,29 +129,19 @@ int main(int argc, char* argv[])
     bool time_kernel = false;
     // GEMM shape
-    // ck::index_t M = 1024;
-    // ck::index_t N = 1024;
-    // ck::index_t K = 64;
-    // ck::index_t O = 64;
-    // ck::index_t BatchCount = 4;
-    // ck::index_t StrideA = 1024;
-    // ck::index_t StrideB0 = 1024;
-    // ck::index_t StrideB1 = 1024;
-    // ck::index_t StrideC = 1024;
-    ck::index_t M = 256;
-    ck::index_t N = 128;
-    ck::index_t K = 32;
-    ck::index_t O = 128;
-    ck::index_t BatchCount = 4;
-    ck::index_t StrideA = 32;
-    ck::index_t StrideB0 = 32;
-    ck::index_t StrideB1 = 128;
-    ck::index_t StrideC = 128;
-    ck::index_t BatchStrideA = -1;
+    ck::index_t M = 1024;
+    ck::index_t N = 1024;
+    ck::index_t K = 64;
+    ck::index_t O = 128;
+    ck::index_t BatchCount = 4;
+    ck::index_t StrideA = -1;
+    ck::index_t StrideB0 = -1;
+    ck::index_t StrideB1 = -1;
+    ck::index_t StrideC = -1;
+    ck::index_t BatchStrideA = -1;
     ck::index_t BatchStrideB0 = -1;
     ck::index_t BatchStrideB1 = -1;
-    ck::index_t BatchStrideC = -1;
+    ck::index_t BatchStrideC = -1;
     if(argc == 1)
     {
@@ -163,6 +153,19 @@ int main(int argc, char* argv[])
         init_method = std::stoi(argv[2]);
         time_kernel = std::stoi(argv[3]);
     }
+    else if(argc == 9)
+    {
+        do_verification = std::stoi(argv[1]);
+        init_method = std::stoi(argv[2]);
+        time_kernel = std::stoi(argv[3]);
+        M = std::stoi(argv[4]);
+        N = std::stoi(argv[5]);
+        K = std::stoi(argv[6]);
+        O = std::stoi(argv[7]);
+        BatchCount = std::stoi(argv[8]);
+    }
     else if(argc == 17)
     {
         do_verification = std::stoi(argv[1]);
@@ -191,7 +194,8 @@ int main(int argc, char* argv[])
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
printf("arg4 to 17: M, N, K, O, Batch, StrideA, StrideB0, StrideB1, StrideC, BatchStrideA, "
"BatchStrideB0, BatchStrideB1, BatchStrideC\n");
exit(0);
}
@@ -205,10 +209,10 @@ int main(int argc, char* argv[])
     StrideB1 = (StrideB1 < 0) ? DefaultStrideB1 : StrideB1;
     StrideC = (StrideC < 0) ? DefaultStrideC : StrideC;
-    const int DefaultBatchStrideA = (ck::is_same_v<ALayout, Row> ? K : M) * StrideA;
-    const int DefaultBatchStrideB0 = (ck::is_same_v<B0Layout, Row> ? N : K) * StrideB0;
-    const int DefaultBatchStrideB1 = (ck::is_same_v<B1Layout, Row> ? O : N) * StrideB1;
-    const int DefaultBatchStrideC = (ck::is_same_v<CLayout, Row> ? O : M) * StrideC;
+    const int DefaultBatchStrideA = (ck::is_same_v<ALayout, Col> ? K : M) * StrideA;
+    const int DefaultBatchStrideB0 = (ck::is_same_v<B0Layout, Col> ? N : K) * StrideB0;
+    const int DefaultBatchStrideB1 = (ck::is_same_v<B1Layout, Col> ? O : N) * StrideB1;
+    const int DefaultBatchStrideC = (ck::is_same_v<CLayout, Col> ? O : M) * StrideC;
     BatchStrideA = BatchStrideA < 0 ? DefaultBatchStrideA : BatchStrideA;
     BatchStrideB0 = BatchStrideB0 < 0 ? DefaultBatchStrideB0 : BatchStrideB0;
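For context on the Row → Col change in this hunk: the default batch stride appears intended to span one full per-batch slice (M x K for A, and so on), so for a row-major tensor the stride must be multiplied by the number of rows, which is what the corrected `Col` test selects. A minimal, self-contained sketch (not part of the commit; the values are the example's new defaults) of how the expression resolves for row-major A:

```cpp
#include <cstdio>

int main()
{
    // New default shape from this commit: M = 1024, K = 64.
    const int M = 1024, K = 64;

    // Row-major A of shape [M, K]: the default row stride is K.
    const int StrideA = K;

    // Old expression: (ALayout == Row ? K : M) * StrideA  ->  K * K (4096)
    // New expression: (ALayout == Col ? K : M) * StrideA  ->  M * K (65536)
    const int old_default = K * StrideA;
    const int new_default = M * StrideA;

    std::printf("old = %d, new = %d, one full M x K slice = %d\n",
                old_default, new_default, M * K);
    return 0;
}
```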
@@ -234,49 +238,49 @@ int main(int argc, char* argv[])
     };
     // C_m_o = A_m_k * B0_k_n * B1_n_o
-    Tensor<ADataType> a_m_k(
+    Tensor<ADataType> a_g_m_k(
         f_host_tensor_descriptor(BatchCount, M, K, StrideA, BatchStrideA, ALayout{}));
-    Tensor<B0DataType> b0_k_n(
+    Tensor<B0DataType> b0_g_k_n(
         f_host_tensor_descriptor(BatchCount, K, N, StrideB0, BatchStrideB0, B0Layout{}));
-    Tensor<B1DataType> b1_n_o(
+    Tensor<B1DataType> b1_g_n_o(
         f_host_tensor_descriptor(BatchCount, N, O, StrideB1, BatchStrideB1, B1Layout{}));
-    Tensor<CDataType> c_m_o_host_result(
+    Tensor<CDataType> c_g_m_o_host_result(
         f_host_tensor_descriptor(BatchCount, M, O, StrideC, BatchStrideC, CLayout{}));
-    Tensor<CDataType> c_m_o_device_result(
+    Tensor<CDataType> c_g_m_o_device_result(
         f_host_tensor_descriptor(BatchCount, M, O, StrideC, BatchStrideC, CLayout{}));
-    std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
-    std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl;
-    std::cout << "b1_n_o: " << b1_n_o.mDesc << std::endl;
-    std::cout << "c_m_o: " << c_m_o_host_result.mDesc << std::endl;
+    std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
+    std::cout << "b0_g_k_n: " << b0_g_k_n.mDesc << std::endl;
+    std::cout << "b1_g_n_o: " << b1_g_n_o.mDesc << std::endl;
+    std::cout << "c_g_m_o: " << c_g_m_o_host_result.mDesc << std::endl;
     switch(init_method)
     {
     case 0: break;
     case 1:
-        a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
-        b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5});
-        b1_n_o.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-5, 5});
+        a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
+        b0_g_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5});
+        b1_g_n_o.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-5, 5});
         break;
     case 2:
-        a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
-        b0_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
-        b1_n_o.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
+        a_g_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
+        b0_g_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
+        b1_g_n_o.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
         break;
     default:
-        a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
-        b0_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
-        b1_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
+        a_g_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
+        b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
+        b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
     }
-    DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
-    DeviceMem b0_k_n_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpace());
-    DeviceMem b1_n_o_device_buf(sizeof(B1DataType) * b1_n_o.mDesc.GetElementSpace());
-    DeviceMem c_m_o_device_buf(sizeof(CDataType) * c_m_o_device_result.mDesc.GetElementSpace());
+    DeviceMem a_g_m_k_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpace());
+    DeviceMem b0_g_k_n_device_buf(sizeof(B0DataType) * b0_g_k_n.mDesc.GetElementSpace());
+    DeviceMem b1_g_n_o_device_buf(sizeof(B1DataType) * b1_g_n_o.mDesc.GetElementSpace());
+    DeviceMem c_g_m_o_device_buf(sizeof(CDataType) * c_g_m_o_device_result.mDesc.GetElementSpace());
-    a_m_k_device_buf.ToDevice(a_m_k.mData.data());
-    b0_k_n_device_buf.ToDevice(b0_k_n.mData.data());
-    b1_n_o_device_buf.ToDevice(b1_n_o.mData.data());
+    a_g_m_k_device_buf.ToDevice(a_g_m_k.mData.data());
+    b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data());
+    b1_g_n_o_device_buf.ToDevice(b1_g_n_o.mData.data());
     auto a_element_op = AElementOp{};
     auto b0_element_op = B0ElementOp{};
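A side note on the `*_g_*` renaming in this hunk: the host tensors now spell out the batch (G) dimension explicitly, and an element (g, m, k) of a packed row-major batched matrix sits at the flat offset g * BatchStride + m * Stride + k. A small illustrative sketch (the helper below is hypothetical and is not the example's `f_host_tensor_descriptor`):

```cpp
#include <cstddef>
#include <cstdio>

// Hypothetical helper: flat offset into a packed [G, M, K] row-major buffer.
std::size_t offset_g_m_k(std::size_t g, std::size_t m, std::size_t k,
                         std::size_t Stride, std::size_t BatchStride)
{
    return g * BatchStride + m * Stride + k;
}

int main()
{
    const std::size_t M = 1024, K = 64;
    const std::size_t StrideA      = K;     // step between consecutive rows
    const std::size_t BatchStrideA = M * K; // step between consecutive batches

    // Element (g = 2, m = 3, k = 5) of a_g_m_k.
    std::printf("%zu\n", offset_g_m_k(2, 3, 5, StrideA, BatchStrideA));
    return 0;
}
```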
@@ -286,10 +290,10 @@ int main(int argc, char* argv[])
     // do GEMM
     auto gemm = DeviceGemmInstance{};
     auto invoker = gemm.MakeInvoker();
-    auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
-                                      static_cast<B0DataType*>(b0_k_n_device_buf.GetDeviceBuffer()),
-                                      static_cast<B1DataType*>(b1_n_o_device_buf.GetDeviceBuffer()),
-                                      static_cast<CDataType*>(c_m_o_device_buf.GetDeviceBuffer()),
+    auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_g_m_k_device_buf.GetDeviceBuffer()),
+                                      static_cast<B0DataType*>(b0_g_k_n_device_buf.GetDeviceBuffer()),
+                                      static_cast<B1DataType*>(b1_g_n_o_device_buf.GetDeviceBuffer()),
+                                      static_cast<CDataType*>(c_g_m_o_device_buf.GetDeviceBuffer()),
                                       M,
                                       N,
                                       K,
@@ -329,28 +333,28 @@ int main(int argc, char* argv[])
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl;
c_m_o_device_buf.FromDevice(c_m_o_device_result.mData.data());
c_g_m_o_device_buf.FromDevice(c_g_m_o_device_result.mData.data());
if(do_verification)
{
// Output of Gemm0 is input A of Gemm1
Tensor<ADataType> a1_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
Tensor<ADataType> a1_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
auto ref_gemm0 = ReferenceGemm0Instance{};
auto ref_gemm0_invoker = ref_gemm0.MakeInvoker();
auto ref_gemm0_argument = ref_gemm0.MakeArgument(
a_m_k, b0_k_n, a1_m_n, a_element_op, b0_element_op, PassThrough{});
a_g_m_k, b0_g_k_n, a1_g_m_n, a_element_op, b0_element_op, PassThrough{});
ref_gemm0_invoker.Run(ref_gemm0_argument);
auto ref_gemm1 = ReferenceGemm1Instance{};
auto ref_gemm1_invoker = ref_gemm1.MakeInvoker();
auto ref_gemm1_argument = ref_gemm1.MakeArgument(
a1_m_n, b1_n_o, c_m_o_host_result, PassThrough{}, b1_element_op, c_element_op);
a1_g_m_n, b1_g_n_o, c_g_m_o_host_result, PassThrough{}, b1_element_op, c_element_op);
ref_gemm1_invoker.Run(ref_gemm1_argument);
return ck::utils::check_err(c_m_o_device_result.mData, c_m_o_host_result.mData) ? 0 : 1;
return ck::utils::check_err(c_g_m_o_device_result.mData, c_g_m_o_host_result.mData) ? 0 : 1;
}
return 0;
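The verification path above checks the fused kernel against two reference GEMMs, following the relation noted in the source comment, C_m_o = A_m_k * B0_k_n * B1_n_o per batch. A naive host sketch of the same computation (illustrative only; this is not the CK reference code and it assumes packed row-major [G, ...] buffers):

```cpp
#include <cstddef>
#include <vector>

// Naive batched GEMM-GEMM:
// C[g][m][o] = sum_n ( sum_k A[g][m][k] * B0[g][k][n] ) * B1[g][n][o]
void batched_gemm_gemm(const std::vector<float>& a,  // [G, M, K]
                       const std::vector<float>& b0, // [G, K, N]
                       const std::vector<float>& b1, // [G, N, O]
                       std::vector<float>& c,        // [G, M, O]
                       std::size_t G, std::size_t M, std::size_t N,
                       std::size_t K, std::size_t O)
{
    for(std::size_t g = 0; g < G; ++g)
        for(std::size_t m = 0; m < M; ++m)
            for(std::size_t o = 0; o < O; ++o)
            {
                float acc1 = 0.f; // accumulator for Gemm1
                for(std::size_t n = 0; n < N; ++n)
                {
                    float acc0 = 0.f; // Gemm0: (A * B0)[g][m][n]
                    for(std::size_t k = 0; k < K; ++k)
                        acc0 += a[(g * M + m) * K + k] * b0[(g * K + k) * N + n];
                    acc1 += acc0 * b1[(g * N + n) * O + o];
                }
                c[(g * M + m) * O + o] = acc1;
            }
}
```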