Commit cb2b5c86 authored by rocking5566's avatar rocking5566 Committed by rocking
Browse files

Merge branch 'develop' into group_norm

parents 7cda0a07 43c898f6
......@@ -10,4 +10,5 @@ struct StreamConfig
{
hipStream_t stream_id_ = nullptr;
bool time_kernel_ = false;
int log_level_ = 0;
};
......@@ -606,9 +606,11 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
#if 1
if(stream_config.log_level_ > 0)
{
arg.Print();
#endif
}
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
arg.b_grid_desc_n_k_,
arg.ds_grid_desc_m_n_,
......
......@@ -49,7 +49,7 @@ template <typename XDataType,
typename BetaDataType,
typename AccDataType,
typename YDataType>
void profile_groupnorm_impl(int do_verification,
bool profile_groupnorm_impl(int do_verification,
int init_method,
bool do_log,
bool time_kernel,
......@@ -61,7 +61,7 @@ void profile_groupnorm_impl(int do_verification,
using Sigmoid = ck::tensor_operation::element_wise::Sigmoid;
if(length.size() != 5)
return;
return false;
index_t G = length[3];
index_t C = length[4];
......@@ -156,6 +156,8 @@ void profile_groupnorm_impl(int do_verification,
}
}
int num_kernel = 0;
for(auto& inst_ptr : instances)
{
auto argument_ptr = inst_ptr->MakeArgumentPointer(
......@@ -172,12 +174,13 @@ void profile_groupnorm_impl(int do_verification,
y_dev.GetDeviceBuffer(),
Sigmoid{});
if(!inst_ptr->IsSupportedArgument(argument_ptr.get()))
if(inst_ptr->IsSupportedArgument(argument_ptr.get()))
{
std::cout << inst_ptr->GetTypeString() << " skipped due to unsupported argument: ";
LogRange(std::cout << "input lengths = [", length, "], ") << std::endl;
return;
++num_kernel;
}
else
{
continue;
}
auto invoker_ptr = inst_ptr->MakeInvokerPointer();
......@@ -191,6 +194,7 @@ void profile_groupnorm_impl(int do_verification,
float gb_per_sec = num_bytes / 1.E6 / avg_time;
if(time_kernel)
std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << gb_per_sec << " GB/s, "
<< inst_ptr->GetTypeString() << std::endl;
......@@ -219,18 +223,30 @@ void profile_groupnorm_impl(int do_verification,
{
std::cout << inst_ptr->GetTypeString() << " failed verification: ";
LogRange(std::cout << "lengths = [", length, ", ") << "]." << std::endl;
return;
return false;
}
else
{
if(time_kernel)
std::cout << "pass" << std::endl;
}
}
}
if(time_kernel)
{
LogRange(std::cout << "length = ", length, ",") << ", ";
std::cout << "best perf = " << best_avg_time << " ms, " << best_gb_per_sec << " GB/s, "
<< best_instance_name << std::endl;
std::cout << "num_kernel = " << num_kernel << ", best perf = " << best_avg_time << " ms, "
<< best_gb_per_sec << " GB/s, " << best_instance_name << std::endl;
}
if(num_kernel == 0)
{
std::cout << "Error: No kernel is tested" << std::endl;
return false;
}
return true;
}
} // namespace profiler
......
......@@ -52,7 +52,7 @@ void print_help_groupnorm()
<< "arg5: print tensor value (0: no; 1: yes)\n"
<< "arg6: time kernel (0=n0, 1=yes)\n"
<< "arg7: out elementwise op (0=passthrough, 1=sigmoid)\n"
<< "--length: tensor extents (e.g, --length 1024 1024) \n"
<< "--length: tensor extents (e.g, --length 1 16 16 32 40) \n"
<< std::endl;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment