Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
a240cebb
Commit
a240cebb
authored
Apr 05, 2021
by
Jing Zhang
Browse files
tweak
parent
1098ced2
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
48 additions
and
18 deletions
+48
-18
composable_kernel/include/utility/amd_buffer_addressing_v2.hpp
...sable_kernel/include/utility/amd_buffer_addressing_v2.hpp
+38
-9
driver/include/device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp
...convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp
+1
-1
driver/src/conv_driver.cpp
driver/src/conv_driver.cpp
+9
-8
No files found.
composable_kernel/include/utility/amd_buffer_addressing_v2.hpp
View file @
a240cebb
...
...
@@ -182,15 +182,16 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
index_t
src_thread_addr_offset
,
index_t
src_wave_addr_offset
)
{
static_assert
((
is_same
<
T
,
float
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
is_same
<
T
,
half_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
))
||
(
is_same
<
T
,
half2_t
>::
value
&&
(
N
==
1
))
||
(
is_same
<
T
,
half4_t
>::
value
&&
(
N
==
1
))
||
(
is_same
<
T
,
half8_t
>::
value
&&
(
N
==
1
))
||
(
is_same
<
T
,
int32_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
is_same
<
T
,
int32x2_t
>::
value
&&
(
N
==
1
))
||
(
is_same
<
T
,
int32x4_t
>::
value
&&
(
N
==
1
)),
"wrong! not implemented"
);
static_assert
(
(
is_same
<
T
,
float
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
is_same
<
T
,
half_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
))
||
(
is_same
<
T
,
half2_t
>::
value
&&
(
N
==
1
))
||
(
is_same
<
T
,
half4_t
>::
value
&&
(
N
==
1
))
||
(
is_same
<
T
,
half8_t
>::
value
&&
(
N
==
1
))
||
(
is_same
<
T
,
int8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
int32_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
is_same
<
T
,
int32x2_t
>::
value
&&
(
N
==
1
))
||
(
is_same
<
T
,
int32x4_t
>::
value
&&
(
N
==
1
)),
"wrong! not implemented"
);
if
constexpr
(
is_same
<
T
,
float
>::
value
)
{
...
...
@@ -277,6 +278,34 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
return
tmp
.
Vector
();
}
}
else
if
constexpr
(
is_same
<
T
,
int8_t
>::
value
)
{
if
constexpr
(
N
==
1
)
{
return
__llvm_amdgcn_raw_buffer_load_i8
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
0
);
}
else
if
constexpr
(
N
==
2
)
{
return
__llvm_amdgcn_raw_buffer_load_i16
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
0
);
}
else
if
constexpr
(
N
==
4
)
{
return
__llvm_amdgcn_raw_buffer_load_i32
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
0
);
}
else
if
constexpr
(
N
==
8
)
{
return
__llvm_amdgcn_raw_buffer_load_i32x2
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
0
);
}
else
if
constexpr
(
N
==
16
)
{
return
__llvm_amdgcn_raw_buffer_load_i32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
0
);
}
}
else
if
constexpr
(
is_same
<
T
,
int32_t
>::
value
)
{
if
constexpr
(
N
==
1
)
...
...
driver/include/device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp
View file @
a240cebb
...
...
@@ -145,7 +145,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
constexpr
index_t
BThreadTransferSrcScalarPerVector_W
=
1
;
constexpr
index_t
CThreadTransferDstScalarPerVector_W
=
1
;
constexpr
index_t
CThreadTransferDstScalarPerVector_W
=
K
1
;
static_assert
(
KPerThread
%
CThreadTransferDstScalarPerVector_W
==
0
,
""
);
#else
...
...
driver/src/conv_driver.cpp
View file @
a240cebb
...
...
@@ -64,7 +64,7 @@ int main(int argc, char* argv[])
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif
0
#elif
1
constexpr
index_t
N
=
1
;
constexpr
index_t
C
=
16
;
constexpr
index_t
HI
=
1080
;
...
...
@@ -73,12 +73,13 @@ int main(int argc, char* argv[])
constexpr
index_t
Y
=
3
;
constexpr
index_t
X
=
3
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
// using ConvDilations = Sequence<1, 1>;
using
ConvDilations
=
Sequence
<
2
,
2
>
;
using
LeftPads
=
Sequence
<
1
,
1
>
;
using
RightPads
=
Sequence
<
1
,
1
>
;
#elif
1
#elif
0
constexpr
index_t
N
=
1
;
constexpr
index_t
C
=
4
;
constexpr
index_t
HI
=
64
;
...
...
@@ -88,7 +89,7 @@ int main(int argc, char* argv[])
constexpr
index_t
X
=
3
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
2
,
2
>
;
using
LeftPads
=
Sequence
<
1
,
1
>
;
using
RightPads
=
Sequence
<
1
,
1
>
;
...
...
@@ -630,14 +631,14 @@ int main(int argc, char* argv[])
print_array
(
"ConvStrides"
,
to_multi_index
(
ConvStrides
{}));
print_array
(
"ConvDilations"
,
to_multi_index
(
ConvDilations
{}));
#if
1
#if
0
using in_data_t = float;
constexpr index_t in_vector_size = 1;
using acc_data_t = float;
using out_data_t = float;
#elif
0
using
in_data_t
=
half_t
;
constexpr
index_t
in_vector_size
=
16
;
constexpr
index_t
in_vector_size
=
4
;
using
acc_data_t
=
float
;
using
out_data_t
=
half_t
;
#elif 0
...
...
@@ -799,7 +800,7 @@ int main(int argc, char* argv[])
check_error
(
out_nkhw_host
,
out_nkhw_device
);
#if
1
#if
0
if(do_log)
{
LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment