Commit a0c0360c authored by rocking
Browse files

Extract variable in example

parent 3af2e4c9
...@@ -26,15 +26,16 @@ using OutElementOp = ck::tensor_operation::element_wise::Add_Activation_Mul_Clam ...@@ -26,15 +26,16 @@ using OutElementOp = ck::tensor_operation::element_wise::Add_Activation_Mul_Clam
static constexpr ck::index_t NumDimSpatial = 2; static constexpr ck::index_t NumDimSpatial = 2;
static constexpr ck::index_t G = 1; static constexpr ck::index_t G = 1;
static constexpr ck::index_t N = 4; // batch size static constexpr ck::index_t N = 4; // batch size
static constexpr ck::index_t K = 64; // output channel static constexpr ck::index_t K = 64; // output channel
static constexpr ck::index_t C = 192; // input channel static constexpr ck::index_t C = 192; // input channel
static constexpr ck::index_t Y = 3; // filter H static constexpr ck::index_t Y = 3; // filter H
static constexpr ck::index_t X = 3; // filter W static constexpr ck::index_t X = 3; // filter W
static constexpr ck::index_t Hi = 71; // input H static constexpr ck::index_t Hi = 71; // input H
static constexpr ck::index_t Wi = 71; // input W static constexpr ck::index_t Wi = 71; // input W
static constexpr ck::index_t Ho = 36; // output H static constexpr ck::index_t Ho = 36; // output H
static constexpr ck::index_t Wo = 36; // output W static constexpr ck::index_t Wo = 36; // output W
static constexpr float requant_scale = 0.5f; // requantize qAcc to qz
struct SimpleDeviceMem struct SimpleDeviceMem
{ {
...@@ -102,26 +103,27 @@ int main(int argc, char* argv[]) ...@@ -102,26 +103,27 @@ int main(int argc, char* argv[])
for(int i = 0; i < op_ptrs.size(); ++i) for(int i = 0; i < op_ptrs.size(); ++i)
{ {
auto& op_ptr = op_ptrs[i]; auto& op_ptr = op_ptrs[i];
auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), auto argument_ptr =
wei.GetDeviceBuffer(), op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(),
{bias.GetDeviceBuffer()}, wei.GetDeviceBuffer(),
out.GetDeviceBuffer(), {bias.GetDeviceBuffer()},
in_lengths, out.GetDeviceBuffer(),
in_strides, in_lengths,
weight_lengths, in_strides,
weight_strides, weight_lengths,
{bias_lengths}, weight_strides,
{bias_strides}, {bias_lengths},
out_lengths, {bias_strides},
out_strides, out_lengths,
conv_strides, out_strides,
conv_dilations, conv_strides,
in_left_pad, conv_dilations,
in_right_pad, in_left_pad,
PassThrough{}, in_right_pad,
PassThrough{}, PassThrough{},
OutElementOp{0.5f, ActivationOp{}}); PassThrough{},
OutElementOp{requant_scale, ActivationOp{}});
auto invoker_ptr = op_ptr->MakeInvokerPointer(); auto invoker_ptr = op_ptr->MakeInvokerPointer();
std::string op_name = op_ptr->GetTypeString(); std::string op_name = op_ptr->GetTypeString();
...@@ -165,25 +167,26 @@ int main(int argc, char* argv[]) ...@@ -165,25 +167,26 @@ int main(int argc, char* argv[])
auto& op_ptr = op_ptrs[best_op_id]; auto& op_ptr = op_ptrs[best_op_id];
std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
<< std::endl; << std::endl;
auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), auto argument_ptr =
wei.GetDeviceBuffer(), op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(),
{bias.GetDeviceBuffer()}, wei.GetDeviceBuffer(),
out.GetDeviceBuffer(), {bias.GetDeviceBuffer()},
in_lengths, out.GetDeviceBuffer(),
in_strides, in_lengths,
weight_lengths, in_strides,
weight_strides, weight_lengths,
{bias_lengths}, weight_strides,
{bias_strides}, {bias_lengths},
out_lengths, {bias_strides},
out_strides, out_lengths,
conv_strides, out_strides,
conv_dilations, conv_strides,
in_left_pad, conv_dilations,
in_right_pad, in_left_pad,
PassThrough{}, in_right_pad,
PassThrough{}, PassThrough{},
OutElementOp{0.5f, ActivationOp{}}); PassThrough{},
OutElementOp{requant_scale, ActivationOp{}});
auto invoker_ptr = op_ptr->MakeInvokerPointer(); auto invoker_ptr = op_ptr->MakeInvokerPointer();
......
...@@ -24,15 +24,16 @@ using OutElementOp = ck::tensor_operation::element_wise::Activation_Mul_Clamp<Ac ...@@ -24,15 +24,16 @@ using OutElementOp = ck::tensor_operation::element_wise::Activation_Mul_Clamp<Ac
static constexpr ck::index_t NumDimSpatial = 2; static constexpr ck::index_t NumDimSpatial = 2;
static constexpr ck::index_t G = 1; static constexpr ck::index_t G = 1;
static constexpr ck::index_t N = 4; // batch size static constexpr ck::index_t N = 4; // batch size
static constexpr ck::index_t K = 64; // output channel static constexpr ck::index_t K = 64; // output channel
static constexpr ck::index_t C = 192; // input channel static constexpr ck::index_t C = 192; // input channel
static constexpr ck::index_t Y = 3; // filter H static constexpr ck::index_t Y = 3; // filter H
static constexpr ck::index_t X = 3; // filter W static constexpr ck::index_t X = 3; // filter W
static constexpr ck::index_t Hi = 71; // input H static constexpr ck::index_t Hi = 71; // input H
static constexpr ck::index_t Wi = 71; // input W static constexpr ck::index_t Wi = 71; // input W
static constexpr ck::index_t Ho = 36; // output H static constexpr ck::index_t Ho = 36; // output H
static constexpr ck::index_t Wo = 36; // output W static constexpr ck::index_t Wo = 36; // output W
static constexpr float requant_scale = 0.5f; // requantize qAcc to qY
struct SimpleDeviceMem struct SimpleDeviceMem
{ {
...@@ -96,26 +97,27 @@ int main(int argc, char* argv[]) ...@@ -96,26 +97,27 @@ int main(int argc, char* argv[])
for(int i = 0; i < op_ptrs.size(); ++i) for(int i = 0; i < op_ptrs.size(); ++i)
{ {
auto& op_ptr = op_ptrs[i]; auto& op_ptr = op_ptrs[i];
auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), auto argument_ptr =
wei.GetDeviceBuffer(), op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(),
{}, wei.GetDeviceBuffer(),
out.GetDeviceBuffer(), {},
in_lengths, out.GetDeviceBuffer(),
in_strides, in_lengths,
weight_lengths, in_strides,
weight_strides, weight_lengths,
{}, weight_strides,
{}, {},
out_lengths, {},
out_strides, out_lengths,
conv_strides, out_strides,
conv_dilations, conv_strides,
in_left_pad, conv_dilations,
in_right_pad, in_left_pad,
PassThrough{}, in_right_pad,
PassThrough{}, PassThrough{},
OutElementOp{0.5f, ActivationOp{}}); PassThrough{},
OutElementOp{requant_scale, ActivationOp{}});
auto invoker_ptr = op_ptr->MakeInvokerPointer(); auto invoker_ptr = op_ptr->MakeInvokerPointer();
std::string op_name = op_ptr->GetTypeString(); std::string op_name = op_ptr->GetTypeString();
...@@ -158,25 +160,26 @@ int main(int argc, char* argv[]) ...@@ -158,25 +160,26 @@ int main(int argc, char* argv[])
auto& op_ptr = op_ptrs[best_op_id]; auto& op_ptr = op_ptrs[best_op_id];
std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
<< std::endl; << std::endl;
auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), auto argument_ptr =
wei.GetDeviceBuffer(), op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(),
{}, wei.GetDeviceBuffer(),
out.GetDeviceBuffer(), {},
in_lengths, out.GetDeviceBuffer(),
in_strides, in_lengths,
weight_lengths, in_strides,
weight_strides, weight_lengths,
{}, weight_strides,
{}, {},
out_lengths, {},
out_strides, out_lengths,
conv_strides, out_strides,
conv_dilations, conv_strides,
in_left_pad, conv_dilations,
in_right_pad, in_left_pad,
PassThrough{}, in_right_pad,
PassThrough{}, PassThrough{},
OutElementOp{0.5f, ActivationOp{}}); PassThrough{},
OutElementOp{requant_scale, ActivationOp{}});
auto invoker_ptr = op_ptr->MakeInvokerPointer(); auto invoker_ptr = op_ptr->MakeInvokerPointer();
......
...@@ -78,6 +78,7 @@ using DeviceGroupedConvNDFwdInstance = ...@@ -78,6 +78,7 @@ using DeviceGroupedConvNDFwdInstance =
int main() int main()
{ {
const auto out_element_op = OutElementOp{0.5f, ActivationOp{}}; float requant_scale = 0.5f;
const auto out_element_op = OutElementOp{requant_scale, ActivationOp{}};
run_conv2d_fwd_bias_perlayer_quantization_example(out_element_op); run_conv2d_fwd_bias_perlayer_quantization_example(out_element_op);
} }
...@@ -81,6 +81,7 @@ using DeviceGroupedConvNDFwdInstance = ...@@ -81,6 +81,7 @@ using DeviceGroupedConvNDFwdInstance =
int main() int main()
{ {
const auto out_element_op = OutElementOp{0.5f, ActivationOp{}}; float scale_z_inv = 0.5f;
const auto out_element_op = OutElementOp{scale_z_inv, ActivationOp{}};
run_conv2d_fwd_bias_perchannel_quantization_example(out_element_op); run_conv2d_fwd_bias_perchannel_quantization_example(out_element_op);
}; };
...@@ -78,6 +78,8 @@ using DeviceGroupedConvNDFwdInstance = ...@@ -78,6 +78,8 @@ using DeviceGroupedConvNDFwdInstance =
int main() int main()
{ {
const auto out_element_op = OutElementOp{0.5f, 0.5f, ActivationOp{}}; float scale_acc = 0.5f;
float scale_z_inv = 0.5f;
const auto out_element_op = OutElementOp{scale_z_inv, scale_acc, ActivationOp{}};
run_conv2d_fwd_bias_perlayer_quantization_example(out_element_op); run_conv2d_fwd_bias_perlayer_quantization_example(out_element_op);
} }
...@@ -73,6 +73,7 @@ using DeviceGroupedConvNDFwdInstance = ...@@ -73,6 +73,7 @@ using DeviceGroupedConvNDFwdInstance =
int main() int main()
{ {
const auto out_element_op = OutElementOp{0.5f, ActivationOp{}}; float requant_scale = 0.5f;
const auto out_element_op = OutElementOp{requant_scale, ActivationOp{}};
run_conv2d_fwd_perlayer_quantization_example(out_element_op); run_conv2d_fwd_perlayer_quantization_example(out_element_op);
} }
...@@ -82,6 +82,7 @@ using DeviceGroupedConvNDFwdInstance = ...@@ -82,6 +82,7 @@ using DeviceGroupedConvNDFwdInstance =
int main() int main()
{ {
const auto out_element_op = OutElementOp{0.5f, ActivationOp{}}; float requant_scale = 0.5f;
const auto out_element_op = OutElementOp{requant_scale, ActivationOp{}};
run_conv2d_fwd_bias_perlayer_quantization_example(out_element_op); run_conv2d_fwd_bias_perlayer_quantization_example(out_element_op);
} }
...@@ -77,6 +77,7 @@ using DeviceGroupedConvNDFwdInstance = ...@@ -77,6 +77,7 @@ using DeviceGroupedConvNDFwdInstance =
int main() int main()
{ {
const auto out_element_op = OutElementOp{0.5f, ActivationOp{}}; float requant_scale = 0.5f;
const auto out_element_op = OutElementOp{requant_scale, ActivationOp{}};
run_conv2d_fwd_perlayer_quantization_example(out_element_op); run_conv2d_fwd_perlayer_quantization_example(out_element_op);
} }
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment