Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
bae8112d
Commit
bae8112d
authored
Mar 09, 2024
by
Jing Zhang
Browse files
enable fwd conv on navi4x
parent
255fbc56
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
13 additions
and
12 deletions
+13
-12
example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt
example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt
+1
-1
example/30_grouped_conv_fwd_multiple_d/common.hpp
example/30_grouped_conv_fwd_multiple_d/common.hpp
+4
-4
example/30_grouped_conv_fwd_multiple_d/common_wmma.hpp
example/30_grouped_conv_fwd_multiple_d/common_wmma.hpp
+4
-4
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp
...impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp
+1
-1
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp
...ation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp
+2
-1
library/include/ck/library/utility/check_err.hpp
library/include/ck/library/utility/check_err.hpp
+1
-1
No files found.
example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt
View file @
bae8112d
list
(
APPEND gpu_list1 gfx908 gfx90a gfx940 gfx941 gfx942 gfx950
)
list
(
APPEND gpu_list1 gfx908 gfx90a gfx940 gfx941 gfx942 gfx950
)
list
(
APPEND gpu_list2 gfx1100 gfx1101 gfx1102
)
list
(
APPEND gpu_list2 gfx1100 gfx1101 gfx1102
gfx1200
)
set
(
target 0
)
set
(
target 0
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
...
...
example/30_grouped_conv_fwd_multiple_d/common.hpp
View file @
bae8112d
...
@@ -90,10 +90,10 @@ struct ExecutionConfig final
...
@@ -90,10 +90,10 @@ struct ExecutionConfig final
bool
time_kernel
=
true
;
bool
time_kernel
=
true
;
};
};
#define DefaultConvParam
\
#define DefaultConvParam \
ck::utils::conv::ConvParam
\
ck::utils::conv::ConvParam \
{
\
{ \
2, 32, 2,
256, 19
2, {3, 3}, {
71, 71
}, {2, 2}, {1, 1}, {1, 1}, { 1, 1 } \
2, 32, 2,
32, 3
2, {3, 3}, {
14, 14
}, {2, 2}, {1, 1}, {1, 1}, { 1, 1 } \
}
}
inline
void
print_help_msg
()
inline
void
print_help_msg
()
...
...
example/30_grouped_conv_fwd_multiple_d/common_wmma.hpp
View file @
bae8112d
...
@@ -90,10 +90,10 @@ struct ExecutionConfig final
...
@@ -90,10 +90,10 @@ struct ExecutionConfig final
bool
time_kernel
=
true
;
bool
time_kernel
=
true
;
};
};
#define DefaultConvParam
\
#define DefaultConvParam \
ck::utils::conv::ConvParam
\
ck::utils::conv::ConvParam \
{
\
{ \
2, 32, 2,
256, 19
2, {3, 3}, {
71, 71
}, {2, 2}, {1, 1}, {1, 1}, { 1, 1 } \
2, 32, 2,
32, 3
2, {3, 3}, {
14, 14
}, {2, 2}, {1, 1}, {1, 1}, { 1, 1 } \
}
}
inline
void
print_help_msg
()
inline
void
print_help_msg
()
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp
View file @
bae8112d
...
@@ -581,7 +581,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
...
@@ -581,7 +581,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
namespace
ctc
=
tensor_layout
::
convolution
;
namespace
ctc
=
tensor_layout
::
convolution
;
// check device
// check device
if
(
ck
::
is_navi3_supported
())
if
(
ck
::
is_navi3_supported
()
||
ck
::
is_navi4_supported
()
)
{
{
if
constexpr
(
!
(
is_same_v
<
AccDataType
,
float
>
||
is_same_v
<
AccDataType
,
int32_t
>
))
if
constexpr
(
!
(
is_same_v
<
AccDataType
,
float
>
||
is_same_v
<
AccDataType
,
int32_t
>
))
{
{
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp
View file @
bae8112d
...
@@ -340,7 +340,8 @@ struct GridwiseGemmMultipleD_Wmma
...
@@ -340,7 +340,8 @@ struct GridwiseGemmMultipleD_Wmma
static
constexpr
auto
MWaves
=
MPerBlock
/
(
MRepeat
*
MPerWmma
);
static
constexpr
auto
MWaves
=
MPerBlock
/
(
MRepeat
*
MPerWmma
);
static
constexpr
auto
NWaves
=
NPerBlock
/
(
NRepeat
*
NPerWmma
);
static
constexpr
auto
NWaves
=
NPerBlock
/
(
NRepeat
*
NPerWmma
);
static
constexpr
auto
WmmaK
=
K1
==
16
?
32
:
16
;
// static constexpr auto WmmaK = K1 == 16 ? 32 : 16;
static
constexpr
auto
WmmaK
=
16
;
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
...
...
library/include/ck/library/utility/check_err.hpp
View file @
bae8112d
...
@@ -156,7 +156,7 @@ check_err(const Range& out,
...
@@ -156,7 +156,7 @@ check_err(const Range& out,
{
{
max_err
=
err
>
max_err
?
err
:
max_err
;
max_err
=
err
>
max_err
?
err
:
max_err
;
err_count
++
;
err_count
++
;
//
if(err_count < 5)
if
(
err_count
<
5
)
{
{
std
::
cerr
<<
msg
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
" out["
<<
i
std
::
cerr
<<
msg
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
" out["
<<
i
<<
"] != ref["
<<
i
<<
"]: "
<<
o
<<
" != "
<<
r
<<
std
::
endl
;
<<
"] != ref["
<<
i
<<
"]: "
<<
o
<<
" != "
<<
r
<<
std
::
endl
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment