"...resnet50_tensorflow.git" did not exist on "bcbce005922c44efd3c5bda5e8c6e811f0fd419e"
Commit 35b07efb authored by Jing Zhang's avatar Jing Zhang
Browse files

fixed ci

parent 661b166e
...@@ -10,7 +10,7 @@ ...@@ -10,7 +10,7 @@
#include "tensor_layout.hpp" #include "tensor_layout.hpp"
#include "device_tensor.hpp" #include "device_tensor.hpp"
#include "element_wise_operation.hpp" #include "element_wise_operation.hpp"
#include "device_gemm.hpp" #include "device_grouped_gemm.hpp"
#include "reference_gemm.hpp" #include "reference_gemm.hpp"
namespace ck { namespace ck {
...@@ -141,9 +141,9 @@ void profile_grouped_gemm_impl(int do_verification, ...@@ -141,9 +141,9 @@ void profile_grouped_gemm_impl(int do_verification,
p_b.reserve(group_count); p_b.reserve(group_count);
p_c.reserve(group_count); p_c.reserve(group_count);
std::vector<ck::tensor_operation::device::GemmShape> gemm_shapes; std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
gemm_shapes.reserve(group_count); gemm_descs.reserve(group_count);
for(std::size_t i = 0; i < group_count; i++) for(std::size_t i = 0; i < group_count; i++)
{ {
...@@ -159,7 +159,7 @@ void profile_grouped_gemm_impl(int do_verification, ...@@ -159,7 +159,7 @@ void profile_grouped_gemm_impl(int do_verification,
b_device_buf[i]->ToDevice(b_k_n[i].mData.data()); b_device_buf[i]->ToDevice(b_k_n[i].mData.data());
c_device_buf[i]->ToDevice(c_m_n_device_results[i].mData.data()); c_device_buf[i]->ToDevice(c_m_n_device_results[i].mData.data());
gemm_shapes.push_back({Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i]}); gemm_descs.push_back({Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i]});
p_a.push_back(a_device_buf[i]->GetDeviceBuffer()); p_a.push_back(a_device_buf[i]->GetDeviceBuffer());
p_b.push_back(b_device_buf[i]->GetDeviceBuffer()); p_b.push_back(b_device_buf[i]->GetDeviceBuffer());
...@@ -221,7 +221,7 @@ void profile_grouped_gemm_impl(int do_verification, ...@@ -221,7 +221,7 @@ void profile_grouped_gemm_impl(int do_verification,
gemm_ptr->MakeArgumentPointer(p_a, gemm_ptr->MakeArgumentPointer(p_a,
p_b, p_b,
p_c, p_c,
gemm_shapes, gemm_descs,
ck::tensor_operation::element_wise::PassThrough{}, ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{}, ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{}); ck::tensor_operation::element_wise::PassThrough{});
...@@ -236,7 +236,7 @@ void profile_grouped_gemm_impl(int do_verification, ...@@ -236,7 +236,7 @@ void profile_grouped_gemm_impl(int do_verification,
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
std::size_t flop = 0, num_btype = 0; std::size_t flop = 0, num_btype = 0;
for(std::size_t i = 0; i < gemm_shapes.size(); i++) for(std::size_t i = 0; i < gemm_descs.size(); i++)
{ {
flop += std::size_t(2) * Ms[i] * Ns[i] * Ks[i]; flop += std::size_t(2) * Ms[i] * Ns[i] * Ks[i];
...@@ -260,7 +260,7 @@ void profile_grouped_gemm_impl(int do_verification, ...@@ -260,7 +260,7 @@ void profile_grouped_gemm_impl(int do_verification,
if(do_verification) if(do_verification)
{ {
for(std::size_t i = 0; i < gemm_shapes.size(); i++) for(std::size_t i = 0; i < gemm_descs.size(); i++)
{ {
c_device_buf[i]->FromDevice(c_m_n_device_results[i].mData.data()); c_device_buf[i]->FromDevice(c_m_n_device_results[i].mData.data());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment