Commit 284178d3 authored by Jehandad Khan
Browse files

Code refactor and add all data types for conv fwd

parent f4965d63
...@@ -7,31 +7,6 @@ ...@@ -7,31 +7,6 @@
#include <vector> #include <vector>
#include "client_app_impl.hpp" #include "client_app_impl.hpp"
// Data-type selector for the convolution profiler.
// The numeric value is what the caller passes as an integer on the command
// line (argv[2] is converted with std::stoi and cast to this enum).
// NOTE(review): each name presumably lists the input/weight/output element
// types in that order - confirm against the instance factories.
enum ConvDataType
{
    F32_F32_F32    = 0,
    F16_F16_F16    = 1,
    BF16_BF16_BF16 = 2,
    INT8_INT8_INT8 = 3,
};
// Input tensor layout selector; the numeric value is the integer given as
// argv[3] on the command line.
enum ConvInputLayout
{
    NCHW = 0,
    NHWC = 1,
};
// Weight (filter) tensor layout selector; the numeric value is the integer
// given as argv[4] on the command line.
enum ConvWeightLayout
{
    KCYX = 0,
    KYXC = 1,
};
// Output tensor layout selector; the numeric value is the integer given as
// argv[5] on the command line.
enum ConvOutputLayout
{
    NKHW = 0,
    NHWK = 1,
};
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
...@@ -51,7 +26,7 @@ int main(int argc, char* argv[]) ...@@ -51,7 +26,7 @@ int main(int argc, char* argv[])
exit(1); exit(1);
} }
const int data_type = static_cast<ConvDataType>(std::stoi(argv[2])); const ConvDataType data_type = static_cast<ConvDataType>(std::stoi(argv[2]));
const int in_layout = static_cast<ConvInputLayout>(std::stoi(argv[3])); const int in_layout = static_cast<ConvInputLayout>(std::stoi(argv[3]));
const int wei_layout = static_cast<ConvWeightLayout>(std::stoi(argv[4])); const int wei_layout = static_cast<ConvWeightLayout>(std::stoi(argv[4]));
const int out_layout = static_cast<ConvOutputLayout>(std::stoi(argv[5])); const int out_layout = static_cast<ConvOutputLayout>(std::stoi(argv[5]));
...@@ -88,6 +63,7 @@ int main(int argc, char* argv[]) ...@@ -88,6 +63,7 @@ int main(int argc, char* argv[])
init_method, init_method,
do_log, do_log,
nrepeat, nrepeat,
data_type,
N, N,
K, K,
C, C,
......
...@@ -2,6 +2,31 @@ ...@@ -2,6 +2,31 @@
#include "host_interface.hpp" #include "host_interface.hpp"
// Data-type selector consumed by profile_conv_fwd_impl to decide which set of
// device convolution instances to register.
// The numeric values are part of the command-line contract: the client casts
// an integer argument directly to this enum.
// NOTE(review): each name presumably lists the input/weight/output element
// types in that order - confirm against the instance factories.
enum ConvDataType
{
    F32_F32_F32    = 0,
    F16_F16_F16    = 1,
    BF16_BF16_BF16 = 2,
    INT8_INT8_INT8 = 3,
};
// Input tensor layout selector; values are cast from an integer command-line
// argument, so the numbering must stay stable.
enum ConvInputLayout
{
    NCHW = 0,
    NHWC = 1,
};
// Weight (filter) tensor layout selector; values are cast from an integer
// command-line argument, so the numbering must stay stable.
enum ConvWeightLayout
{
    KCYX = 0,
    KYXC = 1,
};
// Output tensor layout selector; values are cast from an integer command-line
// argument, so the numbering must stay stable.
enum ConvOutputLayout
{
    NKHW = 0,
    NHWK = 1,
};
namespace ck { namespace ck {
...@@ -43,6 +68,7 @@ void profile_conv_fwd_impl(int do_verification, ...@@ -43,6 +68,7 @@ void profile_conv_fwd_impl(int do_verification,
int init_method, int init_method,
bool do_log, bool do_log,
int nrepeat, int nrepeat,
ConvDataType data_type,
ck::index_t N, ck::index_t N,
ck::index_t K, ck::index_t K,
ck::index_t C, ck::index_t C,
...@@ -63,9 +89,9 @@ void profile_conv_fwd_impl(int do_verification, ...@@ -63,9 +89,9 @@ void profile_conv_fwd_impl(int do_verification,
const ck::index_t Ho = output_spatial_lengths[0]; const ck::index_t Ho = output_spatial_lengths[0];
const ck::index_t Wo = output_spatial_lengths[1]; const ck::index_t Wo = output_spatial_lengths[1];
const auto in_sz = 1000; const auto in_sz = N * C * Hi * Wi;
const auto wei_sz = 1000; const auto wei_sz = K * C * Y * X;
const auto out_sz = 1000; const auto out_sz = N * K * Ho * Wo;
using WeiDataType = float; using WeiDataType = float;
using InDataType = float; using InDataType = float;
...@@ -79,8 +105,19 @@ void profile_conv_fwd_impl(int do_verification, ...@@ -79,8 +105,19 @@ void profile_conv_fwd_impl(int do_verification,
// add device Conv instances // add device Conv instances
std::vector<DeviceConvFwdPtr_t> conv_ptrs; std::vector<DeviceConvFwdPtr_t> conv_ptrs;
if(data_type == F16_F16_F16)
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances_t(conv_ptrs); {
add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances_t(conv_ptrs);
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances_t(conv_ptrs);
}
else if(data_type == BF16_BF16_BF16)
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances_t(conv_ptrs);
else if(data_type == F32_F32_F32)
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances_t(conv_ptrs);
else if(data_type == INT8_INT8_INT8)
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances_t(conv_ptrs);
else
throw std::runtime_error("wrong! Invalid data type");
if(conv_ptrs.empty()) if(conv_ptrs.empty())
{ {
throw std::runtime_error("wrong! no device Conv instance found"); throw std::runtime_error("wrong! no device Conv instance found");
......
...@@ -35,3 +35,7 @@ struct DeviceConvFwdPtr_t ...@@ -35,3 +35,7 @@ struct DeviceConvFwdPtr_t
}; };
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances_t(std::vector<DeviceConvFwdPtr_t>& instances); void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances_t(std::vector<DeviceConvFwdPtr_t>& instances);
// Additional registration entry points, one per data type / kernel variant.
// Each one appends type-erased DeviceConvFwdPtr_t wrappers to `instances`;
// elements already in the vector are left in place.
// f16, "c_shuffle" kernel variant:
void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances_t(std::vector<DeviceConvFwdPtr_t>& instances);
// bf16:
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances_t(std::vector<DeviceConvFwdPtr_t>& instances);
// f16:
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances_t(std::vector<DeviceConvFwdPtr_t>& instances);
// int8:
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances_t(std::vector<DeviceConvFwdPtr_t>& instances);
...@@ -43,6 +43,7 @@ add_library(device_operations STATIC ...@@ -43,6 +43,7 @@ add_library(device_operations STATIC
$<TARGET_OBJECTS:device_gemm_bias_relu_add_instance> $<TARGET_OBJECTS:device_gemm_bias_relu_add_instance>
$<TARGET_OBJECTS:device_gemm_bias2d_instance> $<TARGET_OBJECTS:device_gemm_bias2d_instance>
$<TARGET_OBJECTS:device_reduce_instance> $<TARGET_OBJECTS:device_reduce_instance>
device_conv2d.cpp
) )
add_library(composablekernels::device_operations ALIAS device_operations) add_library(composablekernels::device_operations ALIAS device_operations)
...@@ -77,7 +78,9 @@ target_include_directories(device_operations PUBLIC ...@@ -77,7 +78,9 @@ target_include_directories(device_operations PUBLIC
# and pass down here to be exported # and pass down here to be exported
target_compile_definitions(device_operations target_compile_definitions(device_operations
PUBLIC -DCK_AMD_GPU_GFX908 PUBLIC -DCK_AMD_GPU_GFX908)
target_compile_options(device_operations
PRIVATE -amdgpu-target=gfx908 PRIVATE -amdgpu-target=gfx908
PRIVATE -O3 PRIVATE -O3
) )
......
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" #include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp"
#include "element_wise_operation.hpp" #include "element_wise_operation.hpp"
#include "device_operation_instance.hpp" #include "device_operation_instance.hpp"
#include "host_interface.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -107,85 +106,3 @@ void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances( ...@@ -107,85 +106,3 @@ void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

// Implementation object held by DeviceConvFwdPtr_t through its pImpl member.
// It owns one native device-convolution instance (`el`) and forwards every
// call to it, supplying PassThrough for the three element-wise operations so
// that callers of the wrapper never have to name that type.
//
// This struct must stay an aggregate with `el` as its only data member: the
// registration code builds it with `DeviceConvFwdPtrImpl{std::move(ptr)}`.
struct DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl
{
    // Builds the argument object for one convolution problem.
    // NOTE(review): in_ptr / wei_ptr / out_ptr are presumably the input,
    // weight and output buffers - confirm against the native interface.
    std::unique_ptr<DeviceConvFwdPtr_t::BaseArgument>
    MakeArgumentPointer(void* in_ptr, void* wei_ptr, void* out_ptr,
                        size_t N, size_t K, size_t C,
                        std::vector<ck::index_t> input_spatial_lengths,
                        std::vector<ck::index_t> filter_spatial_lengths,
                        std::vector<ck::index_t> output_spatial_lengths,
                        std::vector<ck::index_t> conv_filter_strides,
                        std::vector<ck::index_t> conv_filter_dilations,
                        std::vector<ck::index_t> input_left_pads,
                        std::vector<ck::index_t> input_right_pads)
    {
        // The vectors arrive by value and are not used again here, so they
        // are moved on instead of being copied a second time.
        return el->MakeArgumentPointer(in_ptr, wei_ptr, out_ptr, N, K, C,
                                       std::move(input_spatial_lengths),
                                       std::move(filter_spatial_lengths),
                                       std::move(output_spatial_lengths),
                                       std::move(conv_filter_strides),
                                       std::move(conv_filter_dilations),
                                       std::move(input_left_pads),
                                       std::move(input_right_pads),
                                       PassThrough{}, PassThrough{}, PassThrough{});
    }

    // Returns the invoker object of the wrapped instance.
    std::unique_ptr<DeviceConvFwdPtr_t::BaseInvoker> MakeInvokerPointer()
    {
        return el->MakeInvokerPointer();
    }

    // Returns the type string reported by the wrapped instance.
    std::string GetTypeString()
    {
        return el->GetTypeString();
    }

    // Asks the wrapped instance whether it supports the given argument.
    bool IsSupportedArgument(const DeviceConvFwdPtr_t::BaseArgument* arg)
    {
        return el->IsSupportedArgument(arg);
    }

    // The wrapped native instance.
    ck::tensor_operation::device::DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough> el;
};
// Default construction leaves the wrapper empty: no implementation attached.
DeviceConvFwdPtr_t::DeviceConvFwdPtr_t()
    : pImpl(nullptr)
{
}

// Wraps an implementation object.
// NOTE: although `other` is taken by lvalue reference, its contents are MOVED
// into the newly allocated implementation, so the caller's object is left in
// a moved-from state after this call.
// (A copying variant of this constructor used to be sketched here as
// commented-out code; only the moving form is implemented.)
DeviceConvFwdPtr_t::DeviceConvFwdPtr_t(DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl& other)
    : pImpl(std::make_unique<DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl>(std::move(other)))
{
}

// Move construction transfers the implementation pointer.
DeviceConvFwdPtr_t::DeviceConvFwdPtr_t(DeviceConvFwdPtr_t&&) = default;

// Defaulted out of line.
// NOTE(review): presumably kept in this file because DeviceConvFwdPtrImpl is
// only a complete type here - confirm against the header.
DeviceConvFwdPtr_t::~DeviceConvFwdPtr_t() = default;
// Builds the argument object for one convolution problem by delegating to the
// attached implementation.
// NOTE(review): in_ptr / wei_ptr / out_ptr are presumably the input, weight
// and output buffers - confirm against the native interface.
// Precondition: an implementation is attached; a default-constructed wrapper
// has a null pImpl and must not be used here.
std::unique_ptr<DeviceConvFwdPtr_t::BaseArgument>
DeviceConvFwdPtr_t::MakeArgumentPointer(void* in_ptr, void* wei_ptr, void* out_ptr,
                                        size_t N, size_t K, size_t C,
                                        std::vector<ck::index_t> input_spatial_lengths,
                                        std::vector<ck::index_t> filter_spatial_lengths,
                                        std::vector<ck::index_t> output_spatial_lengths,
                                        std::vector<ck::index_t> conv_filter_strides,
                                        std::vector<ck::index_t> conv_filter_dilations,
                                        std::vector<ck::index_t> input_left_pads,
                                        std::vector<ck::index_t> input_right_pads)
{
    // The implementation takes these vectors by value as well; moving them
    // avoids copying all seven a second time.
    return pImpl->MakeArgumentPointer(in_ptr, wei_ptr, out_ptr, N, K, C,
                                      std::move(input_spatial_lengths),
                                      std::move(filter_spatial_lengths),
                                      std::move(output_spatial_lengths),
                                      std::move(conv_filter_strides),
                                      std::move(conv_filter_dilations),
                                      std::move(input_left_pads),
                                      std::move(input_right_pads));
}
// Returns the invoker object produced by the attached implementation.
// Precondition: an implementation is attached (pImpl is not null).
std::unique_ptr<DeviceConvFwdPtr_t::BaseInvoker> DeviceConvFwdPtr_t::MakeInvokerPointer()
{
    auto& impl = *pImpl;
    return impl.MakeInvokerPointer();
}
// Returns the type string reported by the attached implementation.
// Precondition: an implementation is attached (pImpl is not null).
std::string DeviceConvFwdPtr_t::GetTypeString()
{
    auto& impl = *pImpl;
    return impl.GetTypeString();
}
// Asks the attached implementation whether it supports `arg_ptr`.
// Precondition: an implementation is attached (pImpl is not null).
bool DeviceConvFwdPtr_t::IsSupportedArgument(const DeviceConvFwdPtr_t::BaseArgument* arg_ptr)
{
    auto& impl = *pImpl;
    return impl.IsSupportedArgument(arg_ptr);
}
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances_t(std::vector<DeviceConvFwdPtr_t>& instances)
{
using namespace ck::tensor_operation::device::device_conv2d_fwd_instance;
std::vector<ck::tensor_operation::device::DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>> local_instances;
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(local_instances);
// convert local_instances to instances
for(auto& kinder: local_instances)
{
DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl tmp{std::move(kinder)};
instances.emplace_back(tmp); // Perhaps we can do better
}
return;
}
#include <stdlib.h>
#include "config.hpp"
#include "device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
#include "host_interface.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_conv2d_fwd_instance {

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

// Forward declarations of the native instance registrars that the wrapper
// functions in this file call. They are only declared here; their definitions
// are not part of this file.

// f32
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(
    std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);

// f16 (plain and c_shuffle variants)
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(
    std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);
void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances(
    std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);

// bf16
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(
    std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);

// int8
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(
    std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);

} // namespace device_conv2d_fwd_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

// Implementation object held by DeviceConvFwdPtr_t through its pImpl member.
// It owns one native device-convolution instance (`el`) and forwards every
// call to it, supplying PassThrough for the three element-wise operations so
// that callers of the wrapper never have to name that type.
//
// This struct must stay an aggregate with `el` as its only data member: the
// registration functions below build it with
// `DeviceConvFwdPtrImpl{std::move(ptr)}`.
struct DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl
{
    // Builds the argument object for one convolution problem.
    // NOTE(review): in_ptr / wei_ptr / out_ptr are presumably the input,
    // weight and output buffers - confirm against the native interface.
    std::unique_ptr<DeviceConvFwdPtr_t::BaseArgument>
    MakeArgumentPointer(void* in_ptr, void* wei_ptr, void* out_ptr,
                        size_t N, size_t K, size_t C,
                        std::vector<ck::index_t> input_spatial_lengths,
                        std::vector<ck::index_t> filter_spatial_lengths,
                        std::vector<ck::index_t> output_spatial_lengths,
                        std::vector<ck::index_t> conv_filter_strides,
                        std::vector<ck::index_t> conv_filter_dilations,
                        std::vector<ck::index_t> input_left_pads,
                        std::vector<ck::index_t> input_right_pads)
    {
        // The vectors arrive by value and are not used again here, so they
        // are moved on instead of being copied a second time.
        return el->MakeArgumentPointer(in_ptr, wei_ptr, out_ptr, N, K, C,
                                       std::move(input_spatial_lengths),
                                       std::move(filter_spatial_lengths),
                                       std::move(output_spatial_lengths),
                                       std::move(conv_filter_strides),
                                       std::move(conv_filter_dilations),
                                       std::move(input_left_pads),
                                       std::move(input_right_pads),
                                       PassThrough{}, PassThrough{}, PassThrough{});
    }

    // Returns the invoker object of the wrapped instance.
    std::unique_ptr<DeviceConvFwdPtr_t::BaseInvoker> MakeInvokerPointer()
    {
        return el->MakeInvokerPointer();
    }

    // Returns the type string reported by the wrapped instance.
    std::string GetTypeString()
    {
        return el->GetTypeString();
    }

    // Asks the wrapped instance whether it supports the given argument.
    bool IsSupportedArgument(const DeviceConvFwdPtr_t::BaseArgument* arg)
    {
        return el->IsSupportedArgument(arg);
    }

    // The wrapped native instance.
    ck::tensor_operation::device::DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough> el;
};
// Default construction leaves the wrapper empty: no implementation attached.
DeviceConvFwdPtr_t::DeviceConvFwdPtr_t()
    : pImpl(nullptr)
{
}

// Wraps an implementation object.
// NOTE: although `other` is taken by lvalue reference, its contents are MOVED
// into the newly allocated implementation, so the caller's object is left in
// a moved-from state after this call.
// (A copying variant of this constructor used to be sketched here as
// commented-out code; only the moving form is implemented.)
DeviceConvFwdPtr_t::DeviceConvFwdPtr_t(DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl& other)
    : pImpl(std::make_unique<DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl>(std::move(other)))
{
}

// Move construction transfers the implementation pointer.
DeviceConvFwdPtr_t::DeviceConvFwdPtr_t(DeviceConvFwdPtr_t&&) = default;

// Defaulted out of line.
// NOTE(review): presumably kept in this file because DeviceConvFwdPtrImpl is
// only a complete type here - confirm against the header.
DeviceConvFwdPtr_t::~DeviceConvFwdPtr_t() = default;
// Builds the argument object for one convolution problem by delegating to the
// attached implementation.
// NOTE(review): in_ptr / wei_ptr / out_ptr are presumably the input, weight
// and output buffers - confirm against the native interface.
// Precondition: an implementation is attached; a default-constructed wrapper
// has a null pImpl and must not be used here.
std::unique_ptr<DeviceConvFwdPtr_t::BaseArgument>
DeviceConvFwdPtr_t::MakeArgumentPointer(void* in_ptr, void* wei_ptr, void* out_ptr,
                                        size_t N, size_t K, size_t C,
                                        std::vector<ck::index_t> input_spatial_lengths,
                                        std::vector<ck::index_t> filter_spatial_lengths,
                                        std::vector<ck::index_t> output_spatial_lengths,
                                        std::vector<ck::index_t> conv_filter_strides,
                                        std::vector<ck::index_t> conv_filter_dilations,
                                        std::vector<ck::index_t> input_left_pads,
                                        std::vector<ck::index_t> input_right_pads)
{
    // The implementation takes these vectors by value as well; moving them
    // avoids copying all seven a second time.
    return pImpl->MakeArgumentPointer(in_ptr, wei_ptr, out_ptr, N, K, C,
                                      std::move(input_spatial_lengths),
                                      std::move(filter_spatial_lengths),
                                      std::move(output_spatial_lengths),
                                      std::move(conv_filter_strides),
                                      std::move(conv_filter_dilations),
                                      std::move(input_left_pads),
                                      std::move(input_right_pads));
}
// Returns the invoker object produced by the attached implementation.
// Precondition: an implementation is attached (pImpl is not null).
std::unique_ptr<DeviceConvFwdPtr_t::BaseInvoker> DeviceConvFwdPtr_t::MakeInvokerPointer()
{
    auto& impl = *pImpl;
    return impl.MakeInvokerPointer();
}
// Returns the type string reported by the attached implementation.
// Precondition: an implementation is attached (pImpl is not null).
std::string DeviceConvFwdPtr_t::GetTypeString()
{
    auto& impl = *pImpl;
    return impl.GetTypeString();
}
// Asks the attached implementation whether it supports `arg_ptr`.
// Precondition: an implementation is attached (pImpl is not null).
bool DeviceConvFwdPtr_t::IsSupportedArgument(const DeviceConvFwdPtr_t::BaseArgument* arg_ptr)
{
    auto& impl = *pImpl;
    return impl.IsSupportedArgument(arg_ptr);
}
using namespace ck::tensor_operation::device::device_conv2d_fwd_instance;
void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances_t(std::vector<DeviceConvFwdPtr_t>& instances)
{
std::vector<ck::tensor_operation::device::DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>> local_instances;
add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances(local_instances);
for(auto& kinder: local_instances)
{
DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl tmp{std::move(kinder)};
instances.emplace_back(tmp);
}
return;
}
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances_t(std::vector<DeviceConvFwdPtr_t>& instances)
{
std::vector<ck::tensor_operation::device::DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>> local_instances;
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(local_instances);
for(auto& kinder: local_instances)
{
DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl tmp{std::move(kinder)};
instances.emplace_back(tmp); // Perhaps we can do better
}
return;
}
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances_t(std::vector<DeviceConvFwdPtr_t>& instances)
{
std::vector<ck::tensor_operation::device::DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>> local_instances;
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(local_instances);
for(auto& kinder: local_instances)
{
DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl tmp{std::move(kinder)};
instances.emplace_back(tmp); // Perhaps we can do better
}
return;
}
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances_t(std::vector<DeviceConvFwdPtr_t>& instances)
{
std::vector<ck::tensor_operation::device::DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>> local_instances;
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(local_instances);
for(auto& kinder: local_instances)
{
DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl tmp{std::move(kinder)};
instances.emplace_back(tmp); // Perhaps we can do better
}
return;
}
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances_t(std::vector<DeviceConvFwdPtr_t>& instances)
{
std::vector<ck::tensor_operation::device::DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>> local_instances;
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(local_instances);
for(auto& kinder: local_instances)
{
DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl tmp{std::move(kinder)};
instances.emplace_back(tmp);
}
return;
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment