Commit cb2b5c86 authored by rocking5566's avatar rocking5566 Committed by rocking
Browse files

Merge branch 'develop' into group_norm

parents 7cda0a07 43c898f6
...@@ -10,4 +10,5 @@ struct StreamConfig ...@@ -10,4 +10,5 @@ struct StreamConfig
{ {
hipStream_t stream_id_ = nullptr; hipStream_t stream_id_ = nullptr;
bool time_kernel_ = false; bool time_kernel_ = false;
int log_level_ = 0;
}; };
...@@ -606,9 +606,11 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle ...@@ -606,9 +606,11 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
#if 1 if(stream_config.log_level_ > 0)
arg.Print(); {
#endif arg.Print();
}
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
arg.b_grid_desc_n_k_, arg.b_grid_desc_n_k_,
arg.ds_grid_desc_m_n_, arg.ds_grid_desc_m_n_,
......
...@@ -49,7 +49,7 @@ template <typename XDataType, ...@@ -49,7 +49,7 @@ template <typename XDataType,
typename BetaDataType, typename BetaDataType,
typename AccDataType, typename AccDataType,
typename YDataType> typename YDataType>
void profile_groupnorm_impl(int do_verification, bool profile_groupnorm_impl(int do_verification,
int init_method, int init_method,
bool do_log, bool do_log,
bool time_kernel, bool time_kernel,
...@@ -61,7 +61,7 @@ void profile_groupnorm_impl(int do_verification, ...@@ -61,7 +61,7 @@ void profile_groupnorm_impl(int do_verification,
using Sigmoid = ck::tensor_operation::element_wise::Sigmoid; using Sigmoid = ck::tensor_operation::element_wise::Sigmoid;
if(length.size() != 5) if(length.size() != 5)
return; return false;
index_t G = length[3]; index_t G = length[3];
index_t C = length[4]; index_t C = length[4];
...@@ -156,6 +156,8 @@ void profile_groupnorm_impl(int do_verification, ...@@ -156,6 +156,8 @@ void profile_groupnorm_impl(int do_verification,
} }
} }
int num_kernel = 0;
for(auto& inst_ptr : instances) for(auto& inst_ptr : instances)
{ {
auto argument_ptr = inst_ptr->MakeArgumentPointer( auto argument_ptr = inst_ptr->MakeArgumentPointer(
...@@ -172,12 +174,13 @@ void profile_groupnorm_impl(int do_verification, ...@@ -172,12 +174,13 @@ void profile_groupnorm_impl(int do_verification,
y_dev.GetDeviceBuffer(), y_dev.GetDeviceBuffer(),
Sigmoid{}); Sigmoid{});
if(!inst_ptr->IsSupportedArgument(argument_ptr.get())) if(inst_ptr->IsSupportedArgument(argument_ptr.get()))
{ {
std::cout << inst_ptr->GetTypeString() << " skipped due to unsupported argument: "; ++num_kernel;
LogRange(std::cout << "input lengths = [", length, "], ") << std::endl; }
else
return; {
continue;
} }
auto invoker_ptr = inst_ptr->MakeInvokerPointer(); auto invoker_ptr = inst_ptr->MakeInvokerPointer();
...@@ -191,8 +194,9 @@ void profile_groupnorm_impl(int do_verification, ...@@ -191,8 +194,9 @@ void profile_groupnorm_impl(int do_verification,
float gb_per_sec = num_bytes / 1.E6 / avg_time; float gb_per_sec = num_bytes / 1.E6 / avg_time;
std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << gb_per_sec << " GB/s, " if(time_kernel)
<< inst_ptr->GetTypeString() << std::endl; std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << gb_per_sec << " GB/s, "
<< inst_ptr->GetTypeString() << std::endl;
if(avg_time < best_avg_time) if(avg_time < best_avg_time)
{ {
...@@ -219,18 +223,30 @@ void profile_groupnorm_impl(int do_verification, ...@@ -219,18 +223,30 @@ void profile_groupnorm_impl(int do_verification,
{ {
std::cout << inst_ptr->GetTypeString() << " failed verification: "; std::cout << inst_ptr->GetTypeString() << " failed verification: ";
LogRange(std::cout << "lengths = [", length, ", ") << "]." << std::endl; LogRange(std::cout << "lengths = [", length, ", ") << "]." << std::endl;
return; return false;
} }
else else
{ {
std::cout << "pass" << std::endl; if(time_kernel)
std::cout << "pass" << std::endl;
} }
} }
} }
LogRange(std::cout << "length = ", length, ",") << ", "; if(time_kernel)
std::cout << "best perf = " << best_avg_time << " ms, " << best_gb_per_sec << " GB/s, " {
<< best_instance_name << std::endl; LogRange(std::cout << "length = ", length, ",") << ", ";
std::cout << "num_kernel = " << num_kernel << ", best perf = " << best_avg_time << " ms, "
<< best_gb_per_sec << " GB/s, " << best_instance_name << std::endl;
}
if(num_kernel == 0)
{
std::cout << "Error: No kernel is tested" << std::endl;
return false;
}
return true;
} }
} // namespace profiler } // namespace profiler
......
...@@ -52,7 +52,7 @@ void print_help_groupnorm() ...@@ -52,7 +52,7 @@ void print_help_groupnorm()
<< "arg5: print tensor value (0: no; 1: yes)\n" << "arg5: print tensor value (0: no; 1: yes)\n"
<< "arg6: time kernel (0=n0, 1=yes)\n" << "arg6: time kernel (0=n0, 1=yes)\n"
<< "arg7: out elementwise op (0=passthrough, 1=sigmoid)\n" << "arg7: out elementwise op (0=passthrough, 1=sigmoid)\n"
<< "--length: tensor extents (e.g, --length 1024 1024) \n" << "--length: tensor extents (e.g, --length 1 16 16 32 40) \n"
<< std::endl; << std::endl;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment