Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
d4881b8a
"vscode:/vscode.git/clone" did not exist on "31336dae3b95a60069fc08ed4024b14f9b71fc11"
Commit
d4881b8a
authored
Mar 16, 2022
by
Jehandad Khan
Browse files
add IsSupportedArgument method
parent
f785032d
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
10 additions
and
2 deletions
+10
-2
example/14_client_app/client_app_impl.hpp
example/14_client_app/client_app_impl.hpp
+1
-1
library/include/ck/library/host/host_interface.hpp
library/include/ck/library/host/host_interface.hpp
+1
-0
library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance2.cpp
...wd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance2.cpp
+8
-1
No files found.
example/14_client_app/client_app_impl.hpp
View file @
d4881b8a
...
@@ -89,7 +89,7 @@ void profile_conv_fwd_impl(int do_verification,
...
@@ -89,7 +89,7 @@ void profile_conv_fwd_impl(int do_verification,
auto
invoker_ptr
=
conv_ptr
.
MakeInvokerPointer
();
auto
invoker_ptr
=
conv_ptr
.
MakeInvokerPointer
();
//
if(conv_ptr.IsSupportedArgument(argument_ptr.get()))
if
(
conv_ptr
.
IsSupportedArgument
(
argument_ptr
.
get
()))
{
{
std
::
string
conv_name
=
conv_ptr
.
GetTypeString
();
std
::
string
conv_name
=
conv_ptr
.
GetTypeString
();
...
...
library/include/ck/library/host/host_interface.hpp
View file @
d4881b8a
...
@@ -31,6 +31,7 @@ struct DeviceConvFwdPtr_t
...
@@ -31,6 +31,7 @@ struct DeviceConvFwdPtr_t
std
::
vector
<
ck
::
index_t
>
input_right_pads
);
// in,wei and out element ops are ignored for now since even if we change them, they cant be linked
std
::
vector
<
ck
::
index_t
>
input_right_pads
);
// in,wei and out element ops are ignored for now since even if we change them, they cant be linked
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
();
// requires including BaseInvoker headers
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
();
// requires including BaseInvoker headers
std
::
string
GetTypeString
();
std
::
string
GetTypeString
();
bool
IsSupportedArgument
(
const
BaseArgument
*
arg_ptr
);
};
};
void
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances_t
(
std
::
vector
<
DeviceConvFwdPtr_t
>&
instances
);
void
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances_t
(
std
::
vector
<
DeviceConvFwdPtr_t
>&
instances
);
library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance2.cpp
View file @
d4881b8a
...
@@ -136,6 +136,10 @@ struct DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl
...
@@ -136,6 +136,10 @@ struct DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl
{
{
return
el
->
GetTypeString
();
return
el
->
GetTypeString
();
}
}
bool
IsSupportedArgument
(
const
DeviceConvFwdPtr_t
::
BaseArgument
*
arg
)
{
return
el
->
IsSupportedArgument
(
arg
);
}
ck
::
tensor_operation
::
device
::
DeviceConvFwdPtr
<
PassThrough
,
PassThrough
,
PassThrough
>
el
;
ck
::
tensor_operation
::
device
::
DeviceConvFwdPtr
<
PassThrough
,
PassThrough
,
PassThrough
>
el
;
};
};
...
@@ -169,10 +173,13 @@ std::string DeviceConvFwdPtr_t::GetTypeString()
...
@@ -169,10 +173,13 @@ std::string DeviceConvFwdPtr_t::GetTypeString()
{
{
return
pImpl
->
GetTypeString
();
return
pImpl
->
GetTypeString
();
}
}
bool
DeviceConvFwdPtr_t
::
IsSupportedArgument
(
const
DeviceConvFwdPtr_t
::
BaseArgument
*
arg_ptr
)
{
return
pImpl
->
IsSupportedArgument
(
arg_ptr
);
}
void
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances_t
(
std
::
vector
<
DeviceConvFwdPtr_t
>&
instances
)
void
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances_t
(
std
::
vector
<
DeviceConvFwdPtr_t
>&
instances
)
{
{
std
::
ignore
=
instances
;
using
namespace
ck
::
tensor_operation
::
device
::
device_conv2d_fwd_instance
;
using
namespace
ck
::
tensor_operation
::
device
::
device_conv2d_fwd_instance
;
std
::
vector
<
ck
::
tensor_operation
::
device
::
DeviceConvFwdPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>
local_instances
;
std
::
vector
<
ck
::
tensor_operation
::
device
::
DeviceConvFwdPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>
local_instances
;
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances
(
local_instances
);
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances
(
local_instances
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment