Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
7b1ce567
"include/vscode:/vscode.git/clone" did not exist on "5b57ab96a8208eec1969a3dcadb555a6246ddb95"
Commit
7b1ce567
authored
Dec 08, 2021
by
Jing Zhang
Browse files
clean code
parent
97a5b74a
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
11 additions
and
21 deletions
+11
-21
device_operation/include/device_gemm_xdl.hpp
device_operation/include/device_gemm_xdl.hpp
+7
-10
host/driver_offline/CMakeLists.txt
host/driver_offline/CMakeLists.txt
+1
-0
host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp
host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp
+3
-11
No files found.
device_operation/include/device_gemm_xdl.hpp
View file @
7b1ce567
...
@@ -80,12 +80,10 @@ struct DeviceGemmXdl
...
@@ -80,12 +80,10 @@ struct DeviceGemmXdl
const
auto
PadM
=
(
MPerBlock
-
M
%
MPerBlock
)
%
MPerBlock
;
const
auto
PadM
=
(
MPerBlock
-
M
%
MPerBlock
)
%
MPerBlock
;
std
::
cout
<<
"PadM = "
<<
PadM
<<
" M = "
<<
M
+
PadM
<<
std
::
endl
;
const
auto
a_grid_desc_k0_m_k1
=
const
auto
a_grid_desc_k0_m_k1
=
transform_tensor_descriptor
(
a_grid_desc_m_k
,
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1Number
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1Number
)),
make_pad_transform
(
M
,
I0
,
PadM
)),
make_
right_
pad_transform
(
M
,
PadM
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
...
@@ -111,12 +109,10 @@ struct DeviceGemmXdl
...
@@ -111,12 +109,10 @@ struct DeviceGemmXdl
const
auto
PadN
=
(
NPerBlock
-
N
%
NPerBlock
)
%
NPerBlock
;
const
auto
PadN
=
(
NPerBlock
-
N
%
NPerBlock
)
%
NPerBlock
;
std
::
cout
<<
"PadN = "
<<
PadN
<<
" N = "
<<
N
+
PadN
<<
std
::
endl
;
const
auto
b_grid_desc_k0_n_k1
=
const
auto
b_grid_desc_k0_n_k1
=
transform_tensor_descriptor
(
b_grid_desc_k_n
,
transform_tensor_descriptor
(
b_grid_desc_k_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1Number
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1Number
)),
make_pad_transform
(
N
,
I0
,
PadN
)),
make_
right_
pad_transform
(
N
,
PadN
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
...
@@ -141,7 +137,7 @@ struct DeviceGemmXdl
...
@@ -141,7 +137,7 @@ struct DeviceGemmXdl
const
auto
c_grid_desc_m_n_
=
transform_tensor_descriptor
(
const
auto
c_grid_desc_m_n_
=
transform_tensor_descriptor
(
c_grid_desc_m_n
,
c_grid_desc_m_n
,
make_tuple
(
make_pad_transform
(
M
,
I0
,
PadM
),
make_pad_transform
(
N
,
I0
,
PadN
)),
make_tuple
(
make_
right_
pad_transform
(
M
,
PadM
),
make_
right_
pad_transform
(
N
,
PadN
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
...
@@ -314,9 +310,10 @@ struct DeviceGemmXdl
...
@@ -314,9 +310,10 @@ struct DeviceGemmXdl
float
Run
(
const
Argument
&
arg
,
int
nrepeat
=
1
)
float
Run
(
const
Argument
&
arg
,
int
nrepeat
=
1
)
{
{
{
{
std
::
cout
<<
"MPerBlock = "
<<
MPerBlock
<<
" NPerBlock = "
<<
NPerBlock
std
::
cout
<<
"BlockGemmShape: {"
<<
MPerBlock
<<
", "
<<
NPerBlock
<<
", "
<<
" MXdlPerWave = "
<<
MXdlPerWave
<<
" NXdlPerWave = "
<<
NXdlPerWave
<<
K0PerBlock
<<
"}, WaveGemmShape: {"
<<
MXdlPerWave
*
MPerXDL
<<
", "
<<
std
::
endl
;
<<
NXdlPerWave
*
NPerXDL
<<
"} XDLGemmShape: {"
<<
MPerXDL
<<
", "
<<
NPerXDL
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"arg.a_grid_desc_k0_m_k1_{"
<<
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
)
std
::
cout
<<
"arg.a_grid_desc_k0_m_k1_{"
<<
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I1
)
<<
", "
<<
", "
<<
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I1
)
<<
", "
...
...
host/driver_offline/CMakeLists.txt
View file @
7b1ce567
...
@@ -10,6 +10,7 @@ include_directories(BEFORE
...
@@ -10,6 +10,7 @@ include_directories(BEFORE
${
PROJECT_SOURCE_DIR
}
/composable_kernel/include/problem_transform
${
PROJECT_SOURCE_DIR
}
/composable_kernel/include/problem_transform
${
PROJECT_SOURCE_DIR
}
/composable_kernel/include/driver
${
PROJECT_SOURCE_DIR
}
/composable_kernel/include/driver
${
PROJECT_SOURCE_DIR
}
/external/rocm/include
${
PROJECT_SOURCE_DIR
}
/external/rocm/include
${
PROJECT_SOURCE_DIR
}
/device_operation/include
)
)
set
(
CONV_FWD_DRIVER_OFFLINE_SOURCE src/conv_fwd_driver_offline.cpp
)
set
(
CONV_FWD_DRIVER_OFFLINE_SOURCE src/conv_fwd_driver_offline.cpp
)
...
...
host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp
View file @
7b1ce567
...
@@ -5,15 +5,7 @@
...
@@ -5,15 +5,7 @@
#include "tensor_descriptor.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdlops_v2r3.hpp"
#include "gridwise_gemm_xdlops_v2r3.hpp"
#include "element_wise_operation.hpp"
struct
OpPassThrough
{
template
<
typename
T
>
__host__
__device__
constexpr
T
operator
()(
T
v
)
const
{
return
v
;
}
};
template
<
ck
::
index_t
BlockSize
,
template
<
ck
::
index_t
BlockSize
,
typename
FloatAB
,
typename
FloatAB
,
...
@@ -79,7 +71,7 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
...
@@ -79,7 +71,7 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
using
ElementwiseOperation
=
Op
PassThrough
;
using
ElementwiseOperation
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
GridwiseGemm
=
using
GridwiseGemm
=
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
<
BlockSize
,
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
<
BlockSize
,
...
@@ -166,7 +158,7 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
...
@@ -166,7 +158,7 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
float
ave_time
=
0
;
float
ave_time
=
0
;
auto
element_op_
=
OpPassThrough
{};
auto
element_op_
=
ElementwiseOperation
{};
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
if
(
has_main_k0_block_loop
)
if
(
has_main_k0_block_loop
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment