Commit be666c79 authored by Jing Zhang

tuned

parent e61489e2
@@ -157,7 +157,7 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con
     if(config.time_kernel)
     {
-        float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
+        float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 20, 50});
         std::size_t flop = std::size_t(2) * M * N * K;
         std::size_t num_btype =
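The only functional change in this hunk is the StreamConfig aggregate passed to the invoker: the trailing fields control how the kernel is timed. As a hedged reading, assuming the field layout Composable Kernel's stream_config header uses around this commit (not copied from this diff), the sketch below shows what the extra values mean. The profiler hunks further down use the same mechanism with lighter settings (5 warm-up / 20 timed) for the sweep and heavier ones (10 / 50) for the final re-measurement.

#include <hip/hip_runtime.h> // for hipStream_t

// Assumed field layout of ck's StreamConfig (a sketch, not taken from the repo):
struct StreamConfig
{
    hipStream_t stream_id_ = nullptr; // nullptr selects the default stream
    bool time_kernel_      = false;   // measure kernel time when true
    int log_level_         = 0;       // extra logging off
    int cold_niters_       = 5;       // untimed warm-up launches
    int nrepeat_           = 50;      // timed launches averaged into ave_time
};

// Under that layout, StreamConfig{nullptr, config.time_kernel, 0, 20, 50}
// requests 20 warm-up runs followed by 50 timed runs on the default stream.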
@@ -30,7 +30,7 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
-using ADataType = F8;
+using ADataType = F16;
 using BDataType = F8;
 using AccDataType = F32;
 using CDataType = F16;
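This turns the example from a uniform f8 GEMM into a mixed-precision one: A is now read as f16 while B stays f8, with f32 accumulation and f16 output. Only ADataType changes; the other three aliases are unchanged context, and the mix is consistent with the f8-to-f16 conversion path tuned in the next hunk.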
@@ -367,11 +367,12 @@ inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x)
     // use native conversion to float and convert to fp16
     // return type_convert<half_t>(type_convert<float>(x));
     // return static_cast<half_t>(x);
-    return static_cast<half_t>(bit_cast<int8_t>(x));
+    return bit_cast<half_t>(bit_cast<int8_t>(x));
 #else
     // constexpr bool negative_zero_nan = true;
     // return utils::cast_from_f8<f8_t, half_t, negative_zero_nan>(x);
-    return static_cast<half_t>(bit_cast<int8_t>(x));
+    uint16_t tmp = bit_cast<uint8_t>(x);
+    return bit_cast<half_t>(tmp);
 #endif
 }
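Both branches replace a value conversion (static_cast) with a raw bit reinterpretation: the 8 f8 bits are zero-extended to 16 bits and reinterpreted as half_t. That does not preserve numeric values, so, consistent with the "tuned" commit message, it reads like a fast placeholder path for timing rather than a correct decode. Also worth noting: if ck's bit_cast enforces equal sizes (as std::bit_cast does), the #if branch's direct bit_cast<half_t> from an 8-bit value would be rejected, while the #else branch avoids this by widening through uint16_t first. A standalone C++20 sketch of the same bit move, with both struct types as stand-ins for ck::f8_t and ck::half_t rather than the library's definitions:

#include <bit>     // std::bit_cast, C++20
#include <cstdint>

struct f8_t   { uint8_t  data; }; // stand-in for ck::f8_t (1 byte)
struct half_t { uint16_t data; }; // stand-in for ck::half_t (2 bytes)

// Mirrors the new #else branch: zero-extend the raw f8 byte into a 16-bit
// value, then reinterpret those 16 bits as half. This is a bit move, not a
// value-preserving f8 -> f16 conversion.
inline half_t convert_sketch(f8_t x)
{
    uint16_t tmp = std::bit_cast<uint8_t>(x); // 8 bits -> low byte of 16
    return std::bit_cast<half_t>(tmp);        // reinterpret as half
}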
@@ -6,6 +6,7 @@
 #include <iomanip>
 #include <iostream>
 #include <typeinfo>
+#include <unistd.h>

 #include "ck/ck.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
@@ -134,11 +135,10 @@ bool profile_gemm_splitk_impl(int do_verification,
         ref_invoker.Run(ref_argument);
     }

-    std::string best_op_name;
-    float best_ave_time = 0;
     float best_tflops = 0;
-    float best_gb_per_sec = 0;
+    int best_instance_id = 0;
     float best_kbatch = 0;
+    int instance_id = 0;

     // profile device GEMM instances
     for(auto& op_ptr : op_ptrs)
@@ -200,8 +200,8 @@ bool profile_gemm_splitk_impl(int do_verification,
             std::string op_name = op_ptr->GetTypeString();

-            float ave_time =
-                invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
+            float ave_time = invoker_ptr->Run(argument_ptr.get(),
+                                              StreamConfig{nullptr, time_kernel, 0, 5, 20});

             std::size_t flop = std::size_t(2) * M * N * K;
@@ -237,11 +237,9 @@ bool profile_gemm_splitk_impl(int do_verification,
                 if(tflops > best_tflops)
                 {
-                    best_op_name = op_name;
                     best_tflops = tflops;
-                    best_ave_time = ave_time;
-                    best_gb_per_sec = gb_per_sec;
                     best_kbatch = kbatch_curr;
+                    best_instance_id = instance_id;
                 }
             }
             else
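Together with the declarations changed above, this hunk reworks the sweep's bookkeeping: instead of snapshotting the winning kernel's name, time, and bandwidth as it goes, the loop now records only the winner's index (best_instance_id) and its kbatch, and the definitive numbers come from re-running that instance after the sweep, as the next hunk adds.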
@@ -250,8 +248,70 @@ bool profile_gemm_splitk_impl(int do_verification,
                           << std::endl;
             }
         }
+        instance_id++;
     }

+    sleep(2);
+
+    {
+        auto& op_ptr = op_ptrs[best_instance_id];
+
+        auto argument_ptr =
+            op_ptr->MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
+                                        static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
+                                        static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
+                                        M,
+                                        N,
+                                        K,
+                                        StrideA,
+                                        StrideB,
+                                        StrideC,
+                                        a_element_op,
+                                        b_element_op,
+                                        c_element_op,
+                                        best_kbatch);
+
+        auto invoker_ptr = op_ptr->MakeInvokerPointer();
+
+        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
+        {
+            // re-init C to zero before profiling next kernel
+            c_device_buf.SetZero();
+
+            invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
+
+            if(do_verification)
+            {
+                c_device_buf.FromDevice(c_m_n_device_result.mData.data());
+
+                pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
+
+                if(do_log)
+                {
+                    LogRangeAsType<float>(std::cout << "a : ", a_m_k.mData, ",") << std::endl;
+                    LogRangeAsType<float>(std::cout << "b: ", b_k_n.mData, ",") << std::endl;
+                    LogRangeAsType<float>(std::cout << "c_host : ", c_m_n_host_result.mData, ",")
+                        << std::endl;
+                    LogRangeAsType<float>(std::cout << "c_device: ", c_m_n_device_result.mData, ",")
+                        << std::endl;
+                }
+            }
+
+            std::string op_name = op_ptr->GetTypeString();
+
+            float ave_time =
+                invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel, 0, 10, 50});
+
+            std::size_t flop = std::size_t(2) * M * N * K;
+            std::size_t num_btype =
+                sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
+
+            float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
+            float gb_per_sec = num_btype / 1.E6 / ave_time;
+
-    if constexpr(is_same<CDataType, float>::value)
-    {
-        std::cout << "Best Perf for datatype = f32";
+            if constexpr(is_same<CDataType, float>::value)
+            {
+                std::cout << "Best Perf for datatype = f32";
@@ -288,9 +348,11 @@ bool profile_gemm_splitk_impl(int do_verification,
-    }
-
-    std::cout << " M = " << M << " N = " << N << " K = " << K << " StrideA = " << StrideA
-              << " StrideB = " << StrideB << " StrideC = " << StrideC << " KBatch = " << best_kbatch
-              << " : " << best_ave_time << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec
-              << " GB/s, " << best_op_name << std::endl;
+            }
+
+            std::cout << " M = " << M << " N = " << N << " K = " << K << " StrideA = " << StrideA
+                      << " StrideB = " << StrideB << " StrideC = " << StrideC
+                      << " KBatch = " << best_kbatch << " : " << ave_time << " ms, " << tflops
+                      << " TFlops, " << gb_per_sec << " GB/s, " << op_name << std::endl;
+        }
+    }

     return pass;
 }
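Taken together, the added block turns the profiler into a two-phase measurement: the sweep times every instance cheaply (5 warm-up / 20 timed runs), then after the two-second sleep(2) pause (the reason for the new <unistd.h> include) only the winner is verified again and re-timed with 10 warm-up and 50 timed runs, and those numbers are what the "Best Perf" line now prints. On units: with ave_time in milliseconds, flop / 1e9 / ave_time yields TFLOPS and num_btype / 1e6 / ave_time yields GB/s. A condensed sketch of the pattern under those assumptions, where the hypothetical 'timer' callable stands in for the MakeArgumentPointer/Run plumbing above:

#include <unistd.h> // sleep
#include <cstddef>

// Two-phase scheme (names hypothetical): timer(id, cold_niters, nrepeat)
// launches instance 'id' with the given warm-up and timed-run counts and
// returns the average kernel time in milliseconds.
template <typename Timer>
float profile_sketch(Timer timer, int num_instances, std::size_t flop)
{
    int best_id       = 0;
    float best_tflops = 0;

    for(int id = 0; id < num_instances; ++id) // phase 1: cheap sweep
    {
        float ms     = timer(id, /*cold_niters=*/5, /*nrepeat=*/20);
        float tflops = float(flop) / 1.E9 / ms; // ms denominator: flop/1e9/ms == TFLOPS
        if(tflops > best_tflops)
        {
            best_tflops = tflops;
            best_id     = id;
        }
    }

    sleep(2); // pause between sweep and final run, as the commit does

    // phase 2: re-measure only the winner, with more repeats (10/50)
    float ms = timer(best_id, /*cold_niters=*/10, /*nrepeat=*/50);
    return float(flop) / 1.E9 / ms;
}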