Commit d4881b8a authored by Jehandad Khan's avatar Jehandad Khan
Browse files

add IsSupportedArgument method

parent f785032d
...@@ -89,7 +89,7 @@ void profile_conv_fwd_impl(int do_verification, ...@@ -89,7 +89,7 @@ void profile_conv_fwd_impl(int do_verification,
auto invoker_ptr = conv_ptr.MakeInvokerPointer(); auto invoker_ptr = conv_ptr.MakeInvokerPointer();
//if(conv_ptr.IsSupportedArgument(argument_ptr.get())) if(conv_ptr.IsSupportedArgument(argument_ptr.get()))
{ {
std::string conv_name = conv_ptr.GetTypeString(); std::string conv_name = conv_ptr.GetTypeString();
......
...@@ -31,6 +31,7 @@ struct DeviceConvFwdPtr_t ...@@ -31,6 +31,7 @@ struct DeviceConvFwdPtr_t
std::vector<ck::index_t> input_right_pads); // in,wei and out element ops are ignored for now since even if we change them, they cant be linked std::vector<ck::index_t> input_right_pads); // in,wei and out element ops are ignored for now since even if we change them, they cant be linked
std::unique_ptr<BaseInvoker> MakeInvokerPointer(); // requires including BaseInvoker headers std::unique_ptr<BaseInvoker> MakeInvokerPointer(); // requires including BaseInvoker headers
std::string GetTypeString(); std::string GetTypeString();
bool IsSupportedArgument(const BaseArgument* arg_ptr);
}; };
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances_t(std::vector<DeviceConvFwdPtr_t>& instances); void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances_t(std::vector<DeviceConvFwdPtr_t>& instances);
...@@ -136,6 +136,10 @@ struct DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl ...@@ -136,6 +136,10 @@ struct DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl
{ {
return el->GetTypeString(); return el->GetTypeString();
} }
bool IsSupportedArgument(const DeviceConvFwdPtr_t::BaseArgument* arg)
{
return el->IsSupportedArgument(arg);
}
ck::tensor_operation::device::DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough> el; ck::tensor_operation::device::DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough> el;
}; };
...@@ -169,10 +173,13 @@ std::string DeviceConvFwdPtr_t::GetTypeString() ...@@ -169,10 +173,13 @@ std::string DeviceConvFwdPtr_t::GetTypeString()
{ {
return pImpl->GetTypeString(); return pImpl->GetTypeString();
} }
bool DeviceConvFwdPtr_t::IsSupportedArgument(const DeviceConvFwdPtr_t::BaseArgument* arg_ptr)
{
return pImpl->IsSupportedArgument(arg_ptr);
}
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances_t(std::vector<DeviceConvFwdPtr_t>& instances) void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances_t(std::vector<DeviceConvFwdPtr_t>& instances)
{ {
std::ignore = instances;
using namespace ck::tensor_operation::device::device_conv2d_fwd_instance; using namespace ck::tensor_operation::device::device_conv2d_fwd_instance;
std::vector<ck::tensor_operation::device::DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>> local_instances; std::vector<ck::tensor_operation::device::DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>> local_instances;
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(local_instances); add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(local_instances);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment