Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
0b41ca2d
Commit
0b41ca2d
authored
Apr 02, 2019
by
Chao Liu
Browse files
puting gridwise convolution into its own class
parent
bdbc0eaa
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
24 additions
and
16 deletions
+24
-16
driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp
driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp
+7
-5
src/include/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hip.hpp
...dwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hip.hpp
+7
-11
src/include/gridwise_convolution_wrapper.hip.hpp
src/include/gridwise_convolution_wrapper.hip.hpp
+10
-0
No files found.
driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp
View file @
0b41ca2d
#pragma once
#pragma once
#include <unistd.h>
#include <unistd.h>
#include "device.hpp"
#include "device.hpp"
#include "gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp"
#include "gridwise_convolution_wrapper.hip.hpp"
#include "gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp"
#include "gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hip.hpp"
//#include "gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp"
template
<
class
T
,
class
InDesc
,
class
WeiDesc
,
class
OutDesc
>
template
<
class
T
,
class
InDesc
,
class
WeiDesc
,
class
OutDesc
>
void
device_implicit_gemm_convolution_2_chwn_cyxk_khwn
(
InDesc
,
void
device_implicit_gemm_convolution_2_chwn_cyxk_khwn
(
InDesc
,
...
@@ -272,7 +273,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
...
@@ -272,7 +273,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
{
{
constexpr
auto
gridwise_conv
=
constexpr
auto
gridwise_conv
=
#if 1
#if 1
g
ridwise
_i
mplicit
_g
emm_
convolution_
2_chwn_cyxk_khwn
G
ridwise
ConvolutionI
mplicit
G
emm_
v
2_chwn_cyxk_khwn
#else
#else
gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer
gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer
#endif
#endif
...
@@ -301,11 +302,12 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
...
@@ -301,11 +302,12 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
WeiBlockCopyThreadPerDim0
,
WeiBlockCopyThreadPerDim0
,
WeiBlockCopyThreadPerDim1
,
WeiBlockCopyThreadPerDim1
,
InBlockCopyDataPerRead
,
InBlockCopyDataPerRead
,
WeiBlockCopyDataPerRead
>
()
;
WeiBlockCopyDataPerRead
>
{}
;
float
time
=
launch_kernel
(
gridwise_conv
.
Run
,
float
time
=
launch_kernel
(
run_
gridwise_conv
olution
<
decltype
(
gridwise_conv
),
T
>
,
dim3
(
GridSize
),
dim3
(
GridSize
),
dim3
(
BlockSize
),
dim3
(
BlockSize
),
gridwise_conv
,
static_cast
<
T
*>
(
in_chwn_device_buf
.
GetDeviceBuffer
()),
static_cast
<
T
*>
(
in_chwn_device_buf
.
GetDeviceBuffer
()),
static_cast
<
T
*>
(
wei_cyxk_device_buf
.
GetDeviceBuffer
()),
static_cast
<
T
*>
(
wei_cyxk_device_buf
.
GetDeviceBuffer
()),
static_cast
<
T
*>
(
out_khwn_device_buf
.
GetDeviceBuffer
()));
static_cast
<
T
*>
(
out_khwn_device_buf
.
GetDeviceBuffer
()));
...
...
src/include/gridwise_implicit_gemm_
convolution_
2_chwn_cyxk_khwn.hip.hpp
→
src/include/gridwise_
convolution_
implicit_gemm_
v
2_chwn_cyxk_khwn.hip.hpp
View file @
0b41ca2d
...
@@ -34,10 +34,11 @@ template <index_t GridSize,
...
@@ -34,10 +34,11 @@ template <index_t GridSize,
index_t
WeiBlockCopyThreadPerDim1
,
index_t
WeiBlockCopyThreadPerDim1
,
index_t
InBlockCopyDataPerRead
,
index_t
InBlockCopyDataPerRead
,
index_t
WeiBlockCopyDataPerRead
>
index_t
WeiBlockCopyDataPerRead
>
class
gridwise_i
mplicit
_g
emm_
convolution_
2_chwn_cyxk_khwn
struct
GridwiseConvolutionI
mplicit
G
emm_
v
2_chwn_cyxk_khwn
{
{
public:
__host__
__device__
constexpr
GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn
()
{}
__host__
__device__
static
index_t
GetSharedMemorySize
()
__host__
__device__
constexpr
index_t
GetSharedMemoryUsage
()
const
{
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
...
@@ -46,7 +47,6 @@ class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn
...
@@ -46,7 +47,6 @@ class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn
constexpr
auto
in_chwn_global_desc
=
InGlobalDesc
{};
constexpr
auto
in_chwn_global_desc
=
InGlobalDesc
{};
constexpr
auto
wei_cyxk_global_desc
=
WeiGlobalDesc
{};
constexpr
auto
wei_cyxk_global_desc
=
WeiGlobalDesc
{};
constexpr
auto
out_khwn_global_desc
=
OutGlobalDesc
{};
constexpr
index_t
Hi
=
in_chwn_global_desc
.
GetLength
(
I1
);
constexpr
index_t
Hi
=
in_chwn_global_desc
.
GetLength
(
I1
);
constexpr
index_t
Wi
=
in_chwn_global_desc
.
GetLength
(
I2
);
constexpr
index_t
Wi
=
in_chwn_global_desc
.
GetLength
(
I2
);
...
@@ -64,10 +64,6 @@ class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn
...
@@ -64,10 +64,6 @@ class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn
constexpr
auto
wei_cyxk_block_desc
=
make_ConstantTensorDescriptor_aligned
(
constexpr
auto
wei_cyxk_block_desc
=
make_ConstantTensorDescriptor_aligned
(
Sequence
<
CPerBlock
,
Y
,
X
,
KPerBlock
>
{},
Number
<
WeiBlockCopyDataPerRead
>
{});
Sequence
<
CPerBlock
,
Y
,
X
,
KPerBlock
>
{},
Number
<
WeiBlockCopyDataPerRead
>
{});
// tensor view of threadwise output in register
constexpr
auto
out_kb_thread_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
KPerThread
,
BPerThread
>
{});
constexpr
index_t
max_align
=
constexpr
index_t
max_align
=
mod_conv
::
max
(
InBlockCopyDataPerRead
,
WeiBlockCopyDataPerRead
);
mod_conv
::
max
(
InBlockCopyDataPerRead
,
WeiBlockCopyDataPerRead
);
...
@@ -81,9 +77,9 @@ class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn
...
@@ -81,9 +77,9 @@ class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn
return
(
in_block_element_space
+
wei_block_element_space
)
*
sizeof
(
Float
);
return
(
in_block_element_space
+
wei_block_element_space
)
*
sizeof
(
Float
);
}
}
__
global__
static
void
Run
(
const
Float
*
const
__restrict__
p_in_global
,
__
device__
void
Run
(
const
Float
*
const
__restrict__
p_in_global
,
const
Float
*
const
__restrict__
p_wei_global
,
const
Float
*
const
__restrict__
p_wei_global
,
Float
*
const
__restrict__
p_out_global
)
Float
*
const
__restrict__
p_out_global
)
const
{
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
...
...
src/include/gridwise_convolution_wrapper.hip.hpp
0 → 100644
View file @
0b41ca2d
#pragma once
template
<
class
GridwiseConvolution
,
class
T
>
__global__
void
run_gridwise_convolution
(
GridwiseConvolution
,
const
T
*
const
__restrict__
p_in_global
,
const
T
*
const
__restrict__
p_wei_global
,
T
*
const
__restrict__
p_out_global
)
{
GridwiseConvolution
{}.
Run
(
p_in_global
,
p_wei_global
,
p_out_global
);
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment