Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
yangql
composable_kernel-1
Commits
1f3870ca
Commit
1f3870ca
authored
Jan 23, 2019
by
Chao Liu
Browse files
another version of blockwise 2d tensor copy
parent
e9ac4855
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
243 additions
and
52 deletions
+243
-52
driver/conv.cu
driver/conv.cu
+2
-2
driver/device_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh
driver/device_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh
+0
-13
driver/device_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh
driver/device_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh
+24
-6
src/include/blockwise_2d_tensor_op.cuh
src/include/blockwise_2d_tensor_op.cuh
+183
-6
src/include/common.cuh
src/include/common.cuh
+2
-0
src/include/gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh
...e/gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh
+4
-0
src/include/gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh
...e/gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh
+28
-25
No files found.
driver/conv.cu
View file @
1f3870ca
...
@@ -376,7 +376,7 @@ int main()
...
@@ -376,7 +376,7 @@ int main()
constexpr
unsigned
K
=
64
;
constexpr
unsigned
K
=
64
;
constexpr
unsigned
S
=
3
;
constexpr
unsigned
S
=
3
;
constexpr
unsigned
R
=
3
;
constexpr
unsigned
R
=
3
;
#elif
0
#elif
1
constexpr
unsigned
N
=
64
;
constexpr
unsigned
N
=
64
;
constexpr
unsigned
C
=
256
;
constexpr
unsigned
C
=
256
;
constexpr
unsigned
HI
=
36
;
constexpr
unsigned
HI
=
36
;
...
@@ -427,7 +427,7 @@ int main()
...
@@ -427,7 +427,7 @@ int main()
#endif
#endif
(
in_nchw_desc
,
in_nchw
,
wei_kcsr_desc
,
wei_kcsr
,
out_nkhw_desc
,
out_nkhw_device
,
nrepeat
);
(
in_nchw_desc
,
in_nchw
,
wei_kcsr_desc
,
wei_kcsr
,
out_nkhw_desc
,
out_nkhw_device
,
nrepeat
);
#if
1
#if
0
host_winograd_3x3_convolution(in_nchw, wei_kcsr, out_nkhw_host);
host_winograd_3x3_convolution(in_nchw, wei_kcsr, out_nkhw_host);
check_error(out_nkhw_host, out_nkhw_device);
check_error(out_nkhw_host, out_nkhw_device);
#elif
0
#elif
0
...
...
driver/device_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh
View file @
1f3870ca
...
@@ -103,19 +103,6 @@ void device_implicit_gemm_convolution_1_nchw_srck_nkhw(InDesc,
...
@@ -103,19 +103,6 @@ void device_implicit_gemm_convolution_1_nchw_srck_nkhw(InDesc,
constexpr
unsigned
HoPerThread
=
2
;
constexpr
unsigned
HoPerThread
=
2
;
constexpr
unsigned
WoPerThread
=
1
;
constexpr
unsigned
WoPerThread
=
1
;
constexpr
unsigned
BlockSize
=
128
;
#elif 0
constexpr
unsigned
NPerBlock
=
2
;
constexpr
unsigned
KPerBlock
=
32
;
constexpr
unsigned
CPerBlock
=
4
;
constexpr
unsigned
HoPerBlock
=
2
;
constexpr
unsigned
WoPerBlock
=
32
;
constexpr
unsigned
KPerThread
=
4
;
constexpr
unsigned
CPerThread
=
2
;
constexpr
unsigned
HoPerThread
=
2
;
constexpr
unsigned
WoPerThread
=
2
;
constexpr
unsigned
BlockSize
=
128
;
constexpr
unsigned
BlockSize
=
128
;
#endif
#endif
...
...
driver/device_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh
View file @
1f3870ca
...
@@ -75,10 +75,23 @@ void device_implicit_gemm_convolution_2_cnhw_srck_knhw(InDesc,
...
@@ -75,10 +75,23 @@ void device_implicit_gemm_convolution_2_cnhw_srck_knhw(InDesc,
constexpr unsigned KPerThread = 1;
constexpr unsigned KPerThread = 1;
constexpr unsigned CPerThread = 1;
constexpr unsigned CPerThread = 1;
constexpr unsigned ThreadPerClusterRow = 1;
constexpr unsigned
Gemm
ThreadPerClusterRow = 1;
constexpr unsigned ThreadPerClusterColumn = 4;
constexpr unsigned
Gemm
ThreadPerClusterColumn = 4;
constexpr unsigned BlockSize = 32;
constexpr unsigned BlockSize = 32;
#elif
0
constexpr
unsigned
BPerBlock
=
128
;
constexpr
unsigned
KPerBlock
=
64
;
constexpr
unsigned
CPerBlock
=
2
;
constexpr
unsigned
BPerThread
=
8
;
constexpr
unsigned
KPerThread
=
8
;
constexpr
unsigned
CPerThread
=
1
;
constexpr
unsigned
GemmThreadPerClusterRow
=
4
;
constexpr
unsigned
GemmThreadPerClusterColumn
=
4
;
constexpr
unsigned
BlockSize
=
128
;
#elif 1
#elif 1
constexpr
unsigned
BPerBlock
=
128
;
constexpr
unsigned
BPerBlock
=
128
;
constexpr
unsigned
KPerBlock
=
64
;
constexpr
unsigned
KPerBlock
=
64
;
...
@@ -88,8 +101,11 @@ void device_implicit_gemm_convolution_2_cnhw_srck_knhw(InDesc,
...
@@ -88,8 +101,11 @@ void device_implicit_gemm_convolution_2_cnhw_srck_knhw(InDesc,
constexpr
unsigned
KPerThread
=
8
;
constexpr
unsigned
KPerThread
=
8
;
constexpr
unsigned
CPerThread
=
1
;
constexpr
unsigned
CPerThread
=
1
;
constexpr
unsigned
ThreadPerClusterRow
=
4
;
constexpr
unsigned
GemmThreadPerClusterRow
=
4
;
constexpr
unsigned
ThreadPerClusterColumn
=
4
;
constexpr
unsigned
GemmThreadPerClusterColumn
=
4
;
constexpr
unsigned
InBlockCopyThreadPerDim0
=
2
;
constexpr
unsigned
InBlockCopyThreadPerDim1
=
64
;
constexpr
unsigned
BlockSize
=
128
;
constexpr
unsigned
BlockSize
=
128
;
#endif
#endif
...
@@ -132,8 +148,10 @@ void device_implicit_gemm_convolution_2_cnhw_srck_knhw(InDesc,
...
@@ -132,8 +148,10 @@ void device_implicit_gemm_convolution_2_cnhw_srck_knhw(InDesc,
BPerThread
,
BPerThread
,
KPerThread
,
KPerThread
,
CPerThread
,
CPerThread
,
ThreadPerClusterRow
,
GemmThreadPerClusterRow
,
ThreadPerClusterColumn
>
GemmThreadPerClusterColumn
,
InBlockCopyThreadPerDim0
,
InBlockCopyThreadPerDim1
>
<<<
grid_dim
,
block_dim
>>>
(
in_cnhw_desc
,
<<<
grid_dim
,
block_dim
>>>
(
in_cnhw_desc
,
static_cast
<
T
*>
(
in_cnhw_device_buf
.
GetDeviceBuffer
()),
static_cast
<
T
*>
(
in_cnhw_device_buf
.
GetDeviceBuffer
()),
wei_srck_desc
,
wei_srck_desc
,
...
...
src/include/blockwise_2d_tensor_op.cuh
View file @
1f3870ca
...
@@ -162,11 +162,188 @@ blockwise_2d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc,
...
@@ -162,11 +162,188 @@ blockwise_2d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc,
}
}
template
<
unsigned
BlockSize
,
class
Float
,
class
SrcDesc
,
class
DstDesc
,
class
SrcOpLengths
>
template
<
unsigned
BlockSize
,
class
Float
,
class
SrcDesc
,
class
DstDesc
,
class
SrcOpLengths
>
__device__
void
blockwise_2d_tensor_copy
(
struct
blockwise_2d_tensor_copy_1
SrcDesc
,
Float
*
const
__restrict__
p_src
,
DstDesc
,
Float
*
__restrict__
p_dst
,
SrcOpLengths
)
{
{
constexpr
auto
dst_from_src_reorder
=
Sequence
<
0
,
1
>
{};
__device__
void
run
(
Float
*
const
__restrict__
p_src
,
Float
*
__restrict__
p_dst
)
const
{
constexpr
auto
dst_from_src_reorder
=
Sequence
<
0
,
1
>
{};
blockwise_2d_tensor_copy_reorder_by_get_dst_from_src
<
BlockSize
>
(
blockwise_2d_tensor_copy_reorder_by_get_dst_from_src
<
BlockSize
>
(
SrcDesc
{},
p_src
,
DstDesc
{},
p_dst
,
SrcOpLengths
{},
dst_from_src_reorder
);
SrcDesc
{},
p_src
,
DstDesc
{},
p_dst
,
SrcOpLengths
{},
dst_from_src_reorder
);
}
}
};
// Blockwise copy of a 2d tensor window of extent SrcOpLengths = [L0, L1]
// from p_src (laid out per SrcDesc) to p_dst (laid out per DstDesc).
//
// The first ThreadPerDim0 * ThreadPerDim1 threads of the block take part and
// are arranged as a ThreadPerDim0 x ThreadPerDim1 grid:
//   mThreadId0 = tid / ThreadPerDim1   (position along dim 0)
//   mThreadId1 = tid % ThreadPerDim1   (position along dim 1)
// Any thread with a larger id returns from run() without copying.
//
// Template parameters:
//   BlockSize      - not referenced inside this struct.
//   Float          - element type of both buffers.
//   SrcDesc        - source descriptor type; must provide Get1dIndex(i0, i1).
//   DstDesc        - destination descriptor type; must provide Get1dIndex(i0, i1).
//   SrcOpLengths   - type whose Get(Number<0>{}) / Get(Number<1>{}) give L0 / L1.
//   ThreadPerDim0  - number of threads along dim 0.
//   ThreadPerDim1  - number of threads along dim 1.
template <unsigned BlockSize,
          class Float,
          class SrcDesc,
          class DstDesc,
          class SrcOpLengths,
          unsigned ThreadPerDim0,
          unsigned ThreadPerDim1>
struct blockwise_2d_tensor_copy_2
{
    unsigned mThreadId0;
    unsigned mThreadId1;

    // Splits the thread's 1d id into its (dim 0, dim 1) position.
    __device__ blockwise_2d_tensor_copy_2()
    {
        mThreadId0 = get_thread_local_1d_id() / ThreadPerDim1;
        mThreadId1 = get_thread_local_1d_id() - mThreadId0 * ThreadPerDim1;
    }

    // Copies the [L0, L1] window from p_src to p_dst.
    // Rows (dim 0) are distributed over mThreadId0 in steps of ThreadPerDim0;
    // a guarded tail pass handles the rows left over when L0 is not a multiple
    // of ThreadPerDim0.
    __device__ void run(Float* const __restrict__ p_src, Float* __restrict__ p_dst) const
    {
        // threads outside the ThreadPerDim0 x ThreadPerDim1 grid do no work
        if(get_thread_local_1d_id() >= ThreadPerDim0 * ThreadPerDim1)
            return;

        constexpr auto I0 = Number<0>{};

        constexpr unsigned L0 = SrcOpLengths{}.Get(I0);

        constexpr unsigned Dim0Loop = L0 / ThreadPerDim0;
        constexpr bool d0_has_tail  = (L0 > ThreadPerDim0 * Dim0Loop);

        // full passes over dim 0: every participating thread has a valid row
        for(unsigned d0loop = 0; d0loop < Dim0Loop; ++d0loop)
        {
            copy_row(p_src, p_dst, d0loop * ThreadPerDim0 + mThreadId0);
        }

        // dim-0 tail: only threads whose row index is still inside L0 copy
        if(d0_has_tail)
        {
            const unsigned did0 = Dim0Loop * ThreadPerDim0 + mThreadId0;

            if(did0 < L0)
            {
                copy_row(p_src, p_dst, did0);
            }
        }
    }

    private:
    // Copies this thread's share of row did0.
    // The row is covered by successive passes in which each thread copies
    // 4, then 2, then 1 consecutive element(s), followed by a guarded tail for
    // the elements left when L1 is not a multiple of ThreadPerDim1.
    // NOTE(review): the 4- and 2-wide passes copy element by element; the
    // "v4"/"v2" naming presumably anticipates vectorized access -- confirm.
    __device__ void
    copy_row(Float* const __restrict__ p_src, Float* __restrict__ p_dst, unsigned did0) const
    {
        constexpr auto I1 = Number<1>{};

        constexpr auto src_desc = SrcDesc{};
        constexpr auto dst_desc = DstDesc{};

        constexpr unsigned L1 = SrcOpLengths{}.Get(I1);

        constexpr unsigned Dim1V4Loop = L1 / (ThreadPerDim1 * 4);

        constexpr unsigned Dim1V2Loop =
            (L1 - Dim1V4Loop * (ThreadPerDim1 * 4)) / (ThreadPerDim1 * 2);

        constexpr unsigned Dim1V1Loop =
            (L1 - Dim1V4Loop * (ThreadPerDim1 * 4) - Dim1V2Loop * (ThreadPerDim1 * 2)) /
            ThreadPerDim1;

        constexpr bool d1_has_tail =
            (L1 > ThreadPerDim1 * (4 * Dim1V4Loop + 2 * Dim1V2Loop + Dim1V1Loop));

        // first dim-1 index handled by each stage
        constexpr unsigned V2Begin   = Dim1V4Loop * 4 * ThreadPerDim1;
        constexpr unsigned V1Begin   = V2Begin + Dim1V2Loop * 2 * ThreadPerDim1;
        constexpr unsigned TailBegin = V1Begin + Dim1V1Loop * ThreadPerDim1;

        // v4: 4 consecutive elements per thread per pass
        for(unsigned d1v4loop = 0; d1v4loop < Dim1V4Loop; ++d1v4loop)
        {
            const unsigned did1 = d1v4loop * 4 * ThreadPerDim1 + 4 * mThreadId1;

            for(unsigned i = 0; i < 4; ++i)
            {
                const unsigned sindex = src_desc.Get1dIndex(did0, did1 + i);
                const unsigned dindex = dst_desc.Get1dIndex(did0, did1 + i);

                p_dst[dindex] = p_src[sindex];
            }
        }

        // v2: 2 consecutive elements per thread per pass
        for(unsigned d1v2loop = 0; d1v2loop < Dim1V2Loop; ++d1v2loop)
        {
            const unsigned did1 = V2Begin + d1v2loop * 2 * ThreadPerDim1 + 2 * mThreadId1;

            for(unsigned i = 0; i < 2; ++i)
            {
                const unsigned sindex = src_desc.Get1dIndex(did0, did1 + i);
                const unsigned dindex = dst_desc.Get1dIndex(did0, did1 + i);

                p_dst[dindex] = p_src[sindex];
            }
        }

        // v1: 1 element per thread per pass
        for(unsigned d1v1loop = 0; d1v1loop < Dim1V1Loop; ++d1v1loop)
        {
            const unsigned did1 = V1Begin + d1v1loop * ThreadPerDim1 + mThreadId1;

            const unsigned sindex = src_desc.Get1dIndex(did0, did1);
            const unsigned dindex = dst_desc.Get1dIndex(did0, did1);

            p_dst[dindex] = p_src[sindex];
        }

        // dim-1 tail: fewer than ThreadPerDim1 elements remain, so guard on L1
        if(d1_has_tail)
        {
            const unsigned did1 = TailBegin + mThreadId1;

            if(did1 < L1)
            {
                const unsigned sindex = src_desc.Get1dIndex(did0, did1);
                const unsigned dindex = dst_desc.Get1dIndex(did0, did1);

                p_dst[dindex] = p_src[sindex];
            }
        }
    }
};
src/include/common.cuh
View file @
1f3870ca
#pragma once
#pragma once
// Warp size constant. No trailing semicolon, so the macro can be used inside
// expressions (e.g. `WARPSIZE * 2`) as well as in plain statements.
#define WARPSIZE 32
template
<
class
T1
,
class
T2
>
template
<
class
T1
,
class
T2
>
struct
is_same
struct
is_same
{
{
...
...
src/include/gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh
View file @
1f3870ca
...
@@ -153,6 +153,7 @@ gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(InGlobalDesc,
...
@@ -153,6 +153,7 @@ gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(InGlobalDesc,
for
(
unsigned
c_block_data_begin
=
0
;
c_block_data_begin
<
in_nchw_global_desc
.
GetLength
(
I1
);
for
(
unsigned
c_block_data_begin
=
0
;
c_block_data_begin
<
in_nchw_global_desc
.
GetLength
(
I1
);
c_block_data_begin
+=
CPerBlock
,
__syncthreads
())
c_block_data_begin
+=
CPerBlock
,
__syncthreads
())
{
{
#if 1
// input: global mem to LDS,
// input: global mem to LDS,
// convert [N,C,Hi,Wi] to [C,Hi,Wi,N]
// convert [N,C,Hi,Wi] to [C,Hi,Wi,N]
blockwise_4d_tensor_copy_reorder_by_get_dst_from_src
<
BlockSize
>
(
blockwise_4d_tensor_copy_reorder_by_get_dst_from_src
<
BlockSize
>
(
...
@@ -165,7 +166,9 @@ gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(InGlobalDesc,
...
@@ -165,7 +166,9 @@ gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(InGlobalDesc,
p_in_block
,
p_in_block
,
in_nchw_block_desc
.
GetLengths
(),
in_nchw_block_desc
.
GetLengths
(),
reorder_chwn_from_nchw
);
reorder_chwn_from_nchw
);
#endif
#if 1
// weight: global mem to LDS,
// weight: global mem to LDS,
// format is [S,R,C,K], no conversion needed
// format is [S,R,C,K], no conversion needed
blockwise_4d_tensor_copy
<
BlockSize
>
(
blockwise_4d_tensor_copy
<
BlockSize
>
(
...
@@ -175,6 +178,7 @@ gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(InGlobalDesc,
...
@@ -175,6 +178,7 @@ gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(InGlobalDesc,
wei_srck_block_desc
,
wei_srck_block_desc
,
p_wei_block
,
p_wei_block
,
wei_srck_block_desc
.
GetLengths
());
wei_srck_block_desc
.
GetLengths
());
#endif
__syncthreads
();
__syncthreads
();
...
...
src/include/gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh
View file @
1f3870ca
...
@@ -20,8 +20,10 @@ template <unsigned GridSize,
...
@@ -20,8 +20,10 @@ template <unsigned GridSize,
unsigned
BPerThread
,
unsigned
BPerThread
,
unsigned
KPerThread
,
unsigned
KPerThread
,
unsigned
CPerThread
,
unsigned
CPerThread
,
unsigned
ThreadPerClusterRow
,
unsigned
GemmThreadPerClusterRow
,
unsigned
ThreadPerClusterColumn
>
unsigned
GemmThreadPerClusterColumn
,
unsigned
InBlockCopyThreadPerDim0
,
unsigned
InBlockCopyThreadPerDim1
>
__global__
void
__global__
void
gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw
(
InGlobalDesc
,
gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw
(
InGlobalDesc
,
Float
*
const
__restrict__
p_in_global
,
Float
*
const
__restrict__
p_in_global
,
...
@@ -104,6 +106,26 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc,
...
@@ -104,6 +106,26 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc,
}
}
#endif
#endif
#if 1
// blockwise 2d copy
const
auto
blockwise_2d_copy
=
blockwise_2d_tensor_copy_1
<
BlockSize
,
Float
,
decltype
(
in_cb_global_desc
),
decltype
(
in_cb_block_desc
),
decltype
(
in_cb_block_desc
.
GetLengths
())
>
{};
#elif 0
// blockwise 2d copy
const
auto
blockwise_2d_copy
=
blockwise_2d_tensor_copy_2
<
BlockSize
,
Float
,
decltype
(
in_cb_global_desc
),
decltype
(
in_cb_block_desc
),
decltype
(
in_cb_block_desc
.
GetLengths
()),
InBlockCopyThreadPerDim0
,
InBlockCopyThreadPerDim1
>
{};
#endif
// a series of blockwise GEMM
// a series of blockwise GEMM
// c_mtx += transpose(a_mtx) * b_mtx
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx and b_mtx saved in LDS, c_mtx saved in register
// a_mtx and b_mtx saved in LDS, c_mtx saved in register
...
@@ -130,8 +152,8 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc,
...
@@ -130,8 +152,8 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc,
false
,
false
,
false
,
false
,
CPerThread
,
CPerThread
,
ThreadPerClusterRow
,
Gemm
ThreadPerClusterRow
,
ThreadPerClusterColumn
,
Gemm
ThreadPerClusterColumn
,
true
>
{};
true
>
{};
// LDS
// LDS
...
@@ -152,12 +174,9 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc,
...
@@ -152,12 +174,9 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc,
{
{
// input: global mem to LDS,
// input: global mem to LDS,
// format is [CPerBlock,BPerBlock + BGhostRead]
// format is [CPerBlock,BPerBlock + BGhostRead]
blockwise_2d_tensor_copy
<
BlockSize
>
(
blockwise_2d_copy
.
run
(
in_cb_global_desc
,
p_in_global
+
in_cb_global_desc
.
Get1dIndex
(
c_block_data_begin
,
b_block_data_begin
),
p_in_global
+
in_cb_global_desc
.
Get1dIndex
(
c_block_data_begin
,
b_block_data_begin
),
in_cb_block_desc
,
p_in_block
);
p_in_block
,
in_cb_block_desc
.
GetLengths
());
// weight: global mem to LDS,
// weight: global mem to LDS,
// format is [S,R,CPerBlock,KPerBlock]
// format is [S,R,CPerBlock,KPerBlock]
...
@@ -245,22 +264,6 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc,
...
@@ -245,22 +264,6 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc,
p_out_global
[
out_knhw_global_desc
.
Get1dIndex
(
k_data
,
n_data
,
h_data
,
w_data
)]
=
p_out_global
[
out_knhw_global_desc
.
Get1dIndex
(
k_data
,
n_data
,
h_data
,
w_data
)]
=
p_out_thread
[
out_kb_thread_desc
.
Get1dIndex
(
k
,
b
)];
p_out_thread
[
out_kb_thread_desc
.
Get1dIndex
(
k
,
b
)];
#endif
#endif
#if 0
if(get_block_1d_id() == 0)
{
printf("%u %u, k %u b %u, k_data %u n_data %u h_data %u w_data %u %f\n",
get_block_1d_id(),
get_thread_local_1d_id(),
k,
b,
k_data,
n_data,
h_data,
w_data,
p_out_thread[out_kb_thread_desc.Get1dIndex(k, b)]);
}
#endif
}
}
}
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment