Commit 22161866 authored by ltqin
Browse files

change d0 desc

parent bef0cb20
...@@ -150,10 +150,10 @@ int main(int argc, char* argv[]) ...@@ -150,10 +150,10 @@ int main(int argc, char* argv[])
int init_method = 1; int init_method = 1;
bool time_kernel = false; bool time_kernel = false;
int G0 = 3; int G0 = 64;
int G1 = 2; int G1 = 12;
int M = 1024; int M = 512;
int N = 1024; int N = 512;
int K = 64; int K = 64;
int O = 64; int O = 64;
float alpha = 1; float alpha = 1;
...@@ -194,12 +194,11 @@ int main(int argc, char* argv[]) ...@@ -194,12 +194,11 @@ int main(int argc, char* argv[])
} }
std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K}; std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
std::vector<ck::index_t> a_gs_ms_ks_strides{ std::vector<ck::index_t> a_gs_ms_ks_strides{M * G1 * K, M * K, K, 1}; // A layout [G0, M, G1, K]
M * G1 * K, K, G1 * K, 1}; // A layout [G0, M, G1, K]
std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1, N, K}; std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1, N, K};
std::vector<ck::index_t> b0_gs_ns_ks_strides{ std::vector<ck::index_t> b0_gs_ns_ks_strides{
N * G1 * K, K, G1 * K, 1}; // B0 layout [G0, N, G1, K] N * G1 * K, N * K, K, 1}; // B0 layout [G0, N, G1, K]
std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1, O, N}; std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1, O, N};
std::vector<ck::index_t> b1_gs_os_ns_strides{ std::vector<ck::index_t> b1_gs_os_ns_strides{
...@@ -211,7 +210,7 @@ int main(int argc, char* argv[]) ...@@ -211,7 +210,7 @@ int main(int argc, char* argv[])
// D layout [G0, M, G1, N] // D layout [G0, M, G1, N]
std::vector<ck::index_t> d0_gs_ms_ns_lengths{G0, G1, M, N}; std::vector<ck::index_t> d0_gs_ms_ns_lengths{G0, G1, M, N};
std::vector<ck::index_t> d0_gs_ms_ns_strides{M * G1 * N, N, G1 * N, 1}; std::vector<ck::index_t> d0_gs_ms_ns_strides{G1 * N, N, 0, 1};
Tensor<ADataType> a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); Tensor<ADataType> a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
Tensor<B0DataType> b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); Tensor<B0DataType> b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
...@@ -224,6 +223,7 @@ int main(int argc, char* argv[]) ...@@ -224,6 +223,7 @@ int main(int argc, char* argv[])
std::cout << "b0_gs_ns_ks: " << b0_gs_ns_ks.mDesc << std::endl; std::cout << "b0_gs_ns_ks: " << b0_gs_ns_ks.mDesc << std::endl;
std::cout << "b1_gs_os_ns: " << b1_gs_os_ns.mDesc << std::endl; std::cout << "b1_gs_os_ns: " << b1_gs_os_ns.mDesc << std::endl;
std::cout << "c_gs_ms_os: " << c_gs_ms_os_host_result.mDesc << std::endl; std::cout << "c_gs_ms_os: " << c_gs_ms_os_host_result.mDesc << std::endl;
std::cout << "d0_gs_ms_ns: " << d0_gs_ms_ns.mDesc << std::endl;
switch(init_method) switch(init_method)
{ {
...@@ -255,7 +255,7 @@ int main(int argc, char* argv[]) ...@@ -255,7 +255,7 @@ int main(int argc, char* argv[])
DeviceMem a_device_buf(sizeof(ADataType) * G0 * G1 * M * K); DeviceMem a_device_buf(sizeof(ADataType) * G0 * G1 * M * K);
DeviceMem b0_device_buf(sizeof(B0DataType) * G0 * G1 * N * K); DeviceMem b0_device_buf(sizeof(B0DataType) * G0 * G1 * N * K);
DeviceMem d0_device_buf(sizeof(D0DataType) * G0 * G1 * M * N); DeviceMem d0_device_buf(sizeof(D0DataType) * G0 * G1 * N);
DeviceMem b1_device_buf(sizeof(B1DataType) * G0 * G1 * O * N); DeviceMem b1_device_buf(sizeof(B1DataType) * G0 * G1 * O * N);
DeviceMem c_device_buf(sizeof(CDataType) * G0 * G1 * M * O); DeviceMem c_device_buf(sizeof(CDataType) * G0 * G1 * M * O);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment