Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
b3a4d179
Commit
b3a4d179
authored
May 16, 2021
by
Jing Zhang
Browse files
fixed output
parent
9bdad55b
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
12 additions
and
10 deletions
+12
-10
composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp
...tion_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp
+2
-1
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops.hpp
...include/tensor_operation/gridwise_dynamic_gemm_xdlops.hpp
+9
-8
driver/src/conv_driver.cpp
driver/src/conv_driver.cpp
+1
-1
No files found.
composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp
View file @
b3a4d179
...
@@ -113,7 +113,8 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad(
...
@@ -113,7 +113,8 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad(
const
auto
out_m0_m1_m2_n_global_desc
=
transform_dynamic_tensor_descriptor
(
const
auto
out_m0_m1_m2_n_global_desc
=
transform_dynamic_tensor_descriptor
(
out_gemmm_gemmn_global_desc
,
out_gemmm_gemmn_global_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
4
,
2
,
4
)),
make_pass_through_transform
(
N
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmM
/
8
,
2
,
4
)),
make_pass_through_transform
(
GemmN
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
>
{}));
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
>
{}));
...
...
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops.hpp
View file @
b3a4d179
...
@@ -476,11 +476,6 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
...
@@ -476,11 +476,6 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
// output: register to global memory
// output: register to global memory
{
{
StaticBuffer
<
AddressSpace
::
Vgpr
,
float
,
64
>
c_thread_buf_
;
static_for
<
0
,
64
,
1
>
{}(
[
&
](
auto
i
)
{
c_thread_buf_
(
i
)
=
c_thread_buf
.
template
AsType
<
float
>()[
i
];
});
constexpr
auto
OutputLayout
=
blockwise_gemm
.
GetOutputLayout
();
constexpr
auto
OutputLayout
=
blockwise_gemm
.
GetOutputLayout
();
constexpr
index_t
K0
=
OutputLayout
.
M1
();
constexpr
index_t
K0
=
OutputLayout
.
M1
();
constexpr
index_t
K1
=
OutputLayout
.
N1
();
constexpr
index_t
K1
=
OutputLayout
.
N1
();
...
@@ -498,8 +493,14 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
...
@@ -498,8 +493,14 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
static_assert
(
BlkSize
==
16
&&
NumBlks
==
4
,
""
);
static_assert
(
BlkSize
==
16
&&
NumBlks
==
4
,
""
);
// force unrolling the output loop to get rid of scratches
// force unrolling the output loop to get rid of scratches
for
(
index_t
i
=
0
;
i
<
NumBlks
;
++
i
)
static_for
<
0
,
NumBlks
,
1
>
{}([
&
](
auto
i
)
{
{
StaticBuffer
<
AddressSpace
::
Vgpr
,
float
,
BlkSize
>
c_thread_buf_
;
static_for
<
0
,
BlkSize
,
1
>
{}([
&
](
auto
j
)
{
c_thread_buf_
(
j
)
=
c_thread_buf
.
template
AsType
<
float
>()[
Number
<
i
*
BlkSize
+
j
>
{}];
});
// calculate origin of thread output tensor on global memory
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
// blockwise GEMM c matrix starting index
const
auto
c_thread_mtx_on_block
=
const
auto
c_thread_mtx_on_block
=
...
@@ -535,7 +536,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
...
@@ -535,7 +536,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
c_m0_m1_m2_n_global_desc
,
c_m0_m1_m2_n_global_desc
,
c_global_buf
,
c_global_buf
,
c_m0_m1_n0_n1_global_tensor_iterator_hacks
);
c_m0_m1_n0_n1_global_tensor_iterator_hacks
);
}
}
);
}
}
}
}
...
...
driver/src/conv_driver.cpp
View file @
b3a4d179
...
@@ -688,7 +688,7 @@ int main(int argc, char* argv[])
...
@@ -688,7 +688,7 @@ int main(int argc, char* argv[])
#elif
0
#elif
0
in_nchw
.
GenerateTensorValue
(
GeneratorTensor_1
{},
num_thread
);
in_nchw
.
GenerateTensorValue
(
GeneratorTensor_1
{},
num_thread
);
wei_kcyx
.
GenerateTensorValue
(
GeneratorTensor_2
{
-
5
,
5
},
num_thread
);
wei_kcyx
.
GenerateTensorValue
(
GeneratorTensor_2
{
-
5
,
5
},
num_thread
);
#elif
1
#elif
0
in_nchw
.
GenerateTensorValue
(
GeneratorTensor_2
{
-
5
,
5
},
num_thread
);
in_nchw
.
GenerateTensorValue
(
GeneratorTensor_2
{
-
5
,
5
},
num_thread
);
wei_kcyx
.
GenerateTensorValue
(
GeneratorTensor_1
{},
num_thread
);
wei_kcyx
.
GenerateTensorValue
(
GeneratorTensor_1
{},
num_thread
);
#elif 1
#elif 1
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment