Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
284178d3
Commit
284178d3
authored
Mar 23, 2022
by
Jehandad Khan
Browse files
Code refactor and add all data types for conv fwd
parent
f4965d63
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
209 additions
and
115 deletions
+209
-115
example/14_client_app/client_app.cpp
example/14_client_app/client_app.cpp
+2
-26
example/14_client_app/client_app_impl.hpp
example/14_client_app/client_app_impl.hpp
+42
-5
library/include/ck/library/host/host_interface.hpp
library/include/ck/library/host/host_interface.hpp
+4
-0
library/src/tensor_operation_instance/gpu/CMakeLists.txt
library/src/tensor_operation_instance/gpu/CMakeLists.txt
+4
-1
library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp
...fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp
+0
-83
library/src/tensor_operation_instance/gpu/device_conv2d.cpp
library/src/tensor_operation_instance/gpu/device_conv2d.cpp
+157
-0
No files found.
example/14_client_app/client_app.cpp
View file @
284178d3
...
@@ -7,31 +7,6 @@
...
@@ -7,31 +7,6 @@
#include <vector>
#include <vector>
#include "client_app_impl.hpp"
#include "client_app_impl.hpp"
enum
ConvDataType
{
F32_F32_F32
,
// 0
F16_F16_F16
,
// 1
BF16_BF16_BF16
,
// 2
INT8_INT8_INT8
,
// 3
};
enum
ConvInputLayout
{
NCHW
,
// 0
NHWC
,
// 1
};
enum
ConvWeightLayout
{
KCYX
,
// 0
KYXC
,
// 1
};
enum
ConvOutputLayout
{
NKHW
,
// 0
NHWK
,
// 1
};
int
main
(
int
argc
,
char
*
argv
[])
int
main
(
int
argc
,
char
*
argv
[])
{
{
...
@@ -51,7 +26,7 @@ int main(int argc, char* argv[])
...
@@ -51,7 +26,7 @@ int main(int argc, char* argv[])
exit
(
1
);
exit
(
1
);
}
}
const
int
data_type
=
static_cast
<
ConvDataType
>
(
std
::
stoi
(
argv
[
2
]));
const
ConvDataType
data_type
=
static_cast
<
ConvDataType
>
(
std
::
stoi
(
argv
[
2
]));
const
int
in_layout
=
static_cast
<
ConvInputLayout
>
(
std
::
stoi
(
argv
[
3
]));
const
int
in_layout
=
static_cast
<
ConvInputLayout
>
(
std
::
stoi
(
argv
[
3
]));
const
int
wei_layout
=
static_cast
<
ConvWeightLayout
>
(
std
::
stoi
(
argv
[
4
]));
const
int
wei_layout
=
static_cast
<
ConvWeightLayout
>
(
std
::
stoi
(
argv
[
4
]));
const
int
out_layout
=
static_cast
<
ConvOutputLayout
>
(
std
::
stoi
(
argv
[
5
]));
const
int
out_layout
=
static_cast
<
ConvOutputLayout
>
(
std
::
stoi
(
argv
[
5
]));
...
@@ -88,6 +63,7 @@ int main(int argc, char* argv[])
...
@@ -88,6 +63,7 @@ int main(int argc, char* argv[])
init_method
,
init_method
,
do_log
,
do_log
,
nrepeat
,
nrepeat
,
data_type
,
N
,
N
,
K
,
K
,
C
,
C
,
...
...
example/14_client_app/client_app_impl.hpp
View file @
284178d3
...
@@ -2,6 +2,31 @@
...
@@ -2,6 +2,31 @@
#include "host_interface.hpp"
#include "host_interface.hpp"
enum
ConvDataType
{
F32_F32_F32
,
// 0
F16_F16_F16
,
// 1
BF16_BF16_BF16
,
// 2
INT8_INT8_INT8
,
// 3
};
enum
ConvInputLayout
{
NCHW
,
// 0
NHWC
,
// 1
};
enum
ConvWeightLayout
{
KCYX
,
// 0
KYXC
,
// 1
};
enum
ConvOutputLayout
{
NKHW
,
// 0
NHWK
,
// 1
};
namespace
ck
{
namespace
ck
{
...
@@ -43,6 +68,7 @@ void profile_conv_fwd_impl(int do_verification,
...
@@ -43,6 +68,7 @@ void profile_conv_fwd_impl(int do_verification,
int
init_method
,
int
init_method
,
bool
do_log
,
bool
do_log
,
int
nrepeat
,
int
nrepeat
,
ConvDataType
data_type
,
ck
::
index_t
N
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
K
,
ck
::
index_t
C
,
ck
::
index_t
C
,
...
@@ -63,9 +89,9 @@ void profile_conv_fwd_impl(int do_verification,
...
@@ -63,9 +89,9 @@ void profile_conv_fwd_impl(int do_verification,
const
ck
::
index_t
Ho
=
output_spatial_lengths
[
0
];
const
ck
::
index_t
Ho
=
output_spatial_lengths
[
0
];
const
ck
::
index_t
Wo
=
output_spatial_lengths
[
1
];
const
ck
::
index_t
Wo
=
output_spatial_lengths
[
1
];
const
auto
in_sz
=
1000
;
const
auto
in_sz
=
N
*
C
*
Hi
*
Wi
;
const
auto
wei_sz
=
1000
;
const
auto
wei_sz
=
K
*
C
*
Y
*
X
;
const
auto
out_sz
=
1000
;
const
auto
out_sz
=
N
*
K
*
Ho
*
Wo
;
using
WeiDataType
=
float
;
using
WeiDataType
=
float
;
using
InDataType
=
float
;
using
InDataType
=
float
;
...
@@ -79,8 +105,19 @@ void profile_conv_fwd_impl(int do_verification,
...
@@ -79,8 +105,19 @@ void profile_conv_fwd_impl(int do_verification,
// add device Conv instances
// add device Conv instances
std
::
vector
<
DeviceConvFwdPtr_t
>
conv_ptrs
;
std
::
vector
<
DeviceConvFwdPtr_t
>
conv_ptrs
;
if
(
data_type
==
F16_F16_F16
)
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances_t
(
conv_ptrs
);
{
add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances_t
(
conv_ptrs
);
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances_t
(
conv_ptrs
);
}
else
if
(
data_type
==
BF16_BF16_BF16
)
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances_t
(
conv_ptrs
);
else
if
(
data_type
==
F32_F32_F32
)
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances_t
(
conv_ptrs
);
else
if
(
data_type
==
INT8_INT8_INT8
)
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances_t
(
conv_ptrs
);
else
throw
std
::
runtime_error
(
"wrong! Invalid data type"
);
if
(
conv_ptrs
.
empty
())
if
(
conv_ptrs
.
empty
())
{
{
throw
std
::
runtime_error
(
"wrong! no device Conv instance found"
);
throw
std
::
runtime_error
(
"wrong! no device Conv instance found"
);
...
...
library/include/ck/library/host/host_interface.hpp
View file @
284178d3
...
@@ -35,3 +35,7 @@ struct DeviceConvFwdPtr_t
...
@@ -35,3 +35,7 @@ struct DeviceConvFwdPtr_t
};
};
void
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances_t
(
std
::
vector
<
DeviceConvFwdPtr_t
>&
instances
);
void
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances_t
(
std
::
vector
<
DeviceConvFwdPtr_t
>&
instances
);
void
add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances_t
(
std
::
vector
<
DeviceConvFwdPtr_t
>&
instances
);
void
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances_t
(
std
::
vector
<
DeviceConvFwdPtr_t
>&
instances
);
void
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances_t
(
std
::
vector
<
DeviceConvFwdPtr_t
>&
instances
);
void
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances_t
(
std
::
vector
<
DeviceConvFwdPtr_t
>&
instances
);
library/src/tensor_operation_instance/gpu/CMakeLists.txt
View file @
284178d3
...
@@ -43,6 +43,7 @@ add_library(device_operations STATIC
...
@@ -43,6 +43,7 @@ add_library(device_operations STATIC
$<TARGET_OBJECTS:device_gemm_bias_relu_add_instance>
$<TARGET_OBJECTS:device_gemm_bias_relu_add_instance>
$<TARGET_OBJECTS:device_gemm_bias2d_instance>
$<TARGET_OBJECTS:device_gemm_bias2d_instance>
$<TARGET_OBJECTS:device_reduce_instance>
$<TARGET_OBJECTS:device_reduce_instance>
device_conv2d.cpp
)
)
add_library
(
composablekernels::device_operations ALIAS device_operations
)
add_library
(
composablekernels::device_operations ALIAS device_operations
)
...
@@ -77,7 +78,9 @@ target_include_directories(device_operations PUBLIC
...
@@ -77,7 +78,9 @@ target_include_directories(device_operations PUBLIC
# and pass down here to be exported
# and pass down here to be exported
target_compile_definitions
(
device_operations
target_compile_definitions
(
device_operations
PUBLIC -DCK_AMD_GPU_GFX908
PUBLIC -DCK_AMD_GPU_GFX908
)
target_compile_options
(
device_operations
PRIVATE -amdgpu-target=gfx908
PRIVATE -amdgpu-target=gfx908
PRIVATE -O3
PRIVATE -O3
)
)
...
...
library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp
View file @
284178d3
...
@@ -3,7 +3,6 @@
...
@@ -3,7 +3,6 @@
#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp"
#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp"
#include "element_wise_operation.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
#include "device_operation_instance.hpp"
#include "host_interface.hpp"
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
...
@@ -107,85 +106,3 @@ void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(
...
@@ -107,85 +106,3 @@ void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
}
// namespace ck
}
// namespace ck
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
struct
DeviceConvFwdPtr_t
::
DeviceConvFwdPtrImpl
{
std
::
unique_ptr
<
DeviceConvFwdPtr_t
::
BaseArgument
>
MakeArgumentPointer
(
void
*
in_ptr
,
void
*
wei_ptr
,
void
*
out_ptr
,
size_t
N
,
size_t
K
,
size_t
C
,
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
)
{
return
el
->
MakeArgumentPointer
(
in_ptr
,
wei_ptr
,
out_ptr
,
N
,
K
,
C
,
input_spatial_lengths
,
filter_spatial_lengths
,
output_spatial_lengths
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
,
PassThrough
{},
PassThrough
{},
PassThrough
{});
}
std
::
unique_ptr
<
DeviceConvFwdPtr_t
::
BaseInvoker
>
MakeInvokerPointer
()
{
return
el
->
MakeInvokerPointer
();
}
std
::
string
GetTypeString
()
{
return
el
->
GetTypeString
();
}
bool
IsSupportedArgument
(
const
DeviceConvFwdPtr_t
::
BaseArgument
*
arg
)
{
return
el
->
IsSupportedArgument
(
arg
);
}
ck
::
tensor_operation
::
device
::
DeviceConvFwdPtr
<
PassThrough
,
PassThrough
,
PassThrough
>
el
;
};
DeviceConvFwdPtr_t
::
DeviceConvFwdPtr_t
()
:
pImpl
(
nullptr
){}
// DeviceConvFwdPtr_t::DeviceConvFwdPtr_t(DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl& impl) : pImpl(std::make_unique<DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl>(impl)) {}
DeviceConvFwdPtr_t
::~
DeviceConvFwdPtr_t
()
=
default
;
DeviceConvFwdPtr_t
::
DeviceConvFwdPtr_t
(
DeviceConvFwdPtr_t
&&
)
=
default
;
DeviceConvFwdPtr_t
::
DeviceConvFwdPtr_t
(
DeviceConvFwdPtr_t
::
DeviceConvFwdPtrImpl
&
other
)
:
pImpl
(
std
::
make_unique
<
DeviceConvFwdPtr_t
::
DeviceConvFwdPtrImpl
>
(
std
::
move
(
other
))){}
std
::
unique_ptr
<
DeviceConvFwdPtr_t
::
BaseArgument
>
DeviceConvFwdPtr_t
::
MakeArgumentPointer
(
void
*
in_ptr
,
void
*
wei_ptr
,
void
*
out_ptr
,
size_t
N
,
size_t
K
,
size_t
C
,
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
)
{
return
pImpl
->
MakeArgumentPointer
(
in_ptr
,
wei_ptr
,
out_ptr
,
N
,
K
,
C
,
input_spatial_lengths
,
filter_spatial_lengths
,
output_spatial_lengths
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
);
}
std
::
unique_ptr
<
DeviceConvFwdPtr_t
::
BaseInvoker
>
DeviceConvFwdPtr_t
::
MakeInvokerPointer
()
{
return
pImpl
->
MakeInvokerPointer
();
}
std
::
string
DeviceConvFwdPtr_t
::
GetTypeString
()
{
return
pImpl
->
GetTypeString
();
}
bool
DeviceConvFwdPtr_t
::
IsSupportedArgument
(
const
DeviceConvFwdPtr_t
::
BaseArgument
*
arg_ptr
)
{
return
pImpl
->
IsSupportedArgument
(
arg_ptr
);
}
void
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances_t
(
std
::
vector
<
DeviceConvFwdPtr_t
>&
instances
)
{
using
namespace
ck
::
tensor_operation
::
device
::
device_conv2d_fwd_instance
;
std
::
vector
<
ck
::
tensor_operation
::
device
::
DeviceConvFwdPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>
local_instances
;
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances
(
local_instances
);
// convert local_instances to instances
for
(
auto
&
kinder
:
local_instances
)
{
DeviceConvFwdPtr_t
::
DeviceConvFwdPtrImpl
tmp
{
std
::
move
(
kinder
)};
instances
.
emplace_back
(
tmp
);
// Perhaps we can do better
}
return
;
}
library/src/tensor_operation_instance/gpu/device_conv2d.cpp
0 → 100644
View file @
284178d3
#include <stdlib.h>
#include "config.hpp"
#include "device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
#include "host_interface.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_conv2d_fwd_instance
{
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
void
add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances
(
std
::
vector
<
DeviceConvFwdPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
);
void
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances
(
std
::
vector
<
DeviceConvFwdPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
);
void
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances
(
std
::
vector
<
DeviceConvFwdPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
);
void
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances
(
std
::
vector
<
DeviceConvFwdPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
);
void
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances
(
std
::
vector
<
DeviceConvFwdPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
);
}
// namespace device_conv2d_fwd_instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
struct
DeviceConvFwdPtr_t
::
DeviceConvFwdPtrImpl
{
std
::
unique_ptr
<
DeviceConvFwdPtr_t
::
BaseArgument
>
MakeArgumentPointer
(
void
*
in_ptr
,
void
*
wei_ptr
,
void
*
out_ptr
,
size_t
N
,
size_t
K
,
size_t
C
,
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
)
{
return
el
->
MakeArgumentPointer
(
in_ptr
,
wei_ptr
,
out_ptr
,
N
,
K
,
C
,
input_spatial_lengths
,
filter_spatial_lengths
,
output_spatial_lengths
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
,
PassThrough
{},
PassThrough
{},
PassThrough
{});
}
std
::
unique_ptr
<
DeviceConvFwdPtr_t
::
BaseInvoker
>
MakeInvokerPointer
()
{
return
el
->
MakeInvokerPointer
();
}
std
::
string
GetTypeString
()
{
return
el
->
GetTypeString
();
}
bool
IsSupportedArgument
(
const
DeviceConvFwdPtr_t
::
BaseArgument
*
arg
)
{
return
el
->
IsSupportedArgument
(
arg
);
}
ck
::
tensor_operation
::
device
::
DeviceConvFwdPtr
<
PassThrough
,
PassThrough
,
PassThrough
>
el
;
};
DeviceConvFwdPtr_t
::
DeviceConvFwdPtr_t
()
:
pImpl
(
nullptr
){}
// DeviceConvFwdPtr_t::DeviceConvFwdPtr_t(DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl& impl) : pImpl(std::make_unique<DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl>(impl)) {}
DeviceConvFwdPtr_t
::~
DeviceConvFwdPtr_t
()
=
default
;
DeviceConvFwdPtr_t
::
DeviceConvFwdPtr_t
(
DeviceConvFwdPtr_t
&&
)
=
default
;
DeviceConvFwdPtr_t
::
DeviceConvFwdPtr_t
(
DeviceConvFwdPtr_t
::
DeviceConvFwdPtrImpl
&
other
)
:
pImpl
(
std
::
make_unique
<
DeviceConvFwdPtr_t
::
DeviceConvFwdPtrImpl
>
(
std
::
move
(
other
))){}
std
::
unique_ptr
<
DeviceConvFwdPtr_t
::
BaseArgument
>
DeviceConvFwdPtr_t
::
MakeArgumentPointer
(
void
*
in_ptr
,
void
*
wei_ptr
,
void
*
out_ptr
,
size_t
N
,
size_t
K
,
size_t
C
,
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
)
{
return
pImpl
->
MakeArgumentPointer
(
in_ptr
,
wei_ptr
,
out_ptr
,
N
,
K
,
C
,
input_spatial_lengths
,
filter_spatial_lengths
,
output_spatial_lengths
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
);
}
std
::
unique_ptr
<
DeviceConvFwdPtr_t
::
BaseInvoker
>
DeviceConvFwdPtr_t
::
MakeInvokerPointer
()
{
return
pImpl
->
MakeInvokerPointer
();
}
std
::
string
DeviceConvFwdPtr_t
::
GetTypeString
()
{
return
pImpl
->
GetTypeString
();
}
bool
DeviceConvFwdPtr_t
::
IsSupportedArgument
(
const
DeviceConvFwdPtr_t
::
BaseArgument
*
arg_ptr
)
{
return
pImpl
->
IsSupportedArgument
(
arg_ptr
);
}
using
namespace
ck
::
tensor_operation
::
device
::
device_conv2d_fwd_instance
;
void
add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances_t
(
std
::
vector
<
DeviceConvFwdPtr_t
>&
instances
)
{
std
::
vector
<
ck
::
tensor_operation
::
device
::
DeviceConvFwdPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>
local_instances
;
add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances
(
local_instances
);
for
(
auto
&
kinder
:
local_instances
)
{
DeviceConvFwdPtr_t
::
DeviceConvFwdPtrImpl
tmp
{
std
::
move
(
kinder
)};
instances
.
emplace_back
(
tmp
);
}
return
;
}
void
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances_t
(
std
::
vector
<
DeviceConvFwdPtr_t
>&
instances
)
{
std
::
vector
<
ck
::
tensor_operation
::
device
::
DeviceConvFwdPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>
local_instances
;
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances
(
local_instances
);
for
(
auto
&
kinder
:
local_instances
)
{
DeviceConvFwdPtr_t
::
DeviceConvFwdPtrImpl
tmp
{
std
::
move
(
kinder
)};
instances
.
emplace_back
(
tmp
);
// Perhaps we can do better
}
return
;
}
void
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances_t
(
std
::
vector
<
DeviceConvFwdPtr_t
>&
instances
)
{
std
::
vector
<
ck
::
tensor_operation
::
device
::
DeviceConvFwdPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>
local_instances
;
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances
(
local_instances
);
for
(
auto
&
kinder
:
local_instances
)
{
DeviceConvFwdPtr_t
::
DeviceConvFwdPtrImpl
tmp
{
std
::
move
(
kinder
)};
instances
.
emplace_back
(
tmp
);
// Perhaps we can do better
}
return
;
}
void
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances_t
(
std
::
vector
<
DeviceConvFwdPtr_t
>&
instances
)
{
std
::
vector
<
ck
::
tensor_operation
::
device
::
DeviceConvFwdPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>
local_instances
;
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances
(
local_instances
);
for
(
auto
&
kinder
:
local_instances
)
{
DeviceConvFwdPtr_t
::
DeviceConvFwdPtrImpl
tmp
{
std
::
move
(
kinder
)};
instances
.
emplace_back
(
tmp
);
// Perhaps we can do better
}
return
;
}
void
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances_t
(
std
::
vector
<
DeviceConvFwdPtr_t
>&
instances
)
{
std
::
vector
<
ck
::
tensor_operation
::
device
::
DeviceConvFwdPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>
local_instances
;
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances
(
local_instances
);
for
(
auto
&
kinder
:
local_instances
)
{
DeviceConvFwdPtr_t
::
DeviceConvFwdPtrImpl
tmp
{
std
::
move
(
kinder
)};
instances
.
emplace_back
(
tmp
);
}
return
;
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment