Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
yangql
composable_kernel-1
Commits
471830a0
Commit
471830a0
authored
Apr 09, 2019
by
Chao Liu
Browse files
tidy yp
parent
1bd880a6
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
98 additions
and
111 deletions
+98
-111
driver/device_implicit_gemm_convolution_1_chwn_cyxk_khwn.hpp
driver/device_implicit_gemm_convolution_1_chwn_cyxk_khwn.hpp
+31
-31
src/include/ConstantTensorDescriptor.hip.hpp
src/include/ConstantTensorDescriptor.hip.hpp
+4
-2
src/include/blockwise_4d_tensor_op.hip.hpp
src/include/blockwise_4d_tensor_op.hip.hpp
+6
-5
src/include/gridwise_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hip.hpp
...dwise_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hip.hpp
+9
-17
src/include/gridwise_convolution_implicit_gemm_v1_chwn_cyxk_khwn_lds_double_buffer.hip.hpp
...implicit_gemm_v1_chwn_cyxk_khwn_lds_double_buffer.hip.hpp
+17
-26
src/include/gridwise_convolution_implicit_gemm_v1r2_chwn_cyxk_khwn.hip.hpp
...ise_convolution_implicit_gemm_v1r2_chwn_cyxk_khwn.hip.hpp
+24
-28
src/include/threadwise_gemm.hip.hpp
src/include/threadwise_gemm.hip.hpp
+7
-2
No files found.
driver/device_implicit_gemm_convolution_1_chwn_cyxk_khwn.hpp
View file @
471830a0
...
...
@@ -346,37 +346,37 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
#elif 0
GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn_lds_double_buffer
#endif
<
GridSize
,
BlockSize
,
T
,
decltype
(
in_chwn_desc
),
decltype
(
wei_cyxk_desc
),
decltype
(
out_khwn_desc
),
NPerBlock
,
KPerBlock
,
CPerBlock
,
HoPerBlock
,
WoPerBlock
,
NPerThread
,
KPerThread
,
HoPerThread
,
WoPerThread
,
GemmMPerThreadSubC
,
GemmNPerThreadSubC
,
GemmMLevel0Cluster
,
GemmNLevel0Cluster
,
GemmMLevel1Cluster
,
GemmNLevel1Cluster
,
GemmKPerThreadLoop
,
GemmDataPerReadA
,
GemmDataPerReadB
,
Sequence
<
InBlockCopy_ThreadPerDimC
,
InBlockCopy_ThreadPerDimH
,
InBlockCopy_ThreadPerDimW
,
InBlockCopy_ThreadPerDimN
>
,
InBlockCopyDataPerRead
,
WeiBlockCopyDataPerRead
,
OutThreadCopyDataPerWrite
>
{};
<
GridSize
,
BlockSize
,
T
,
decltype
(
in_chwn_desc
),
decltype
(
wei_cyxk_desc
),
decltype
(
out_khwn_desc
),
NPerBlock
,
KPerBlock
,
CPerBlock
,
HoPerBlock
,
WoPerBlock
,
NPerThread
,
KPerThread
,
HoPerThread
,
WoPerThread
,
GemmMPerThreadSubC
,
GemmNPerThreadSubC
,
GemmMLevel0Cluster
,
GemmNLevel0Cluster
,
GemmMLevel1Cluster
,
GemmNLevel1Cluster
,
GemmKPerThreadLoop
,
GemmDataPerReadA
,
GemmDataPerReadB
,
Sequence
<
InBlockCopy_ThreadPerDimC
,
InBlockCopy_ThreadPerDimH
,
InBlockCopy_ThreadPerDimW
,
InBlockCopy_ThreadPerDimN
>
,
InBlockCopyDataPerRead
,
WeiBlockCopyDataPerRead
,
OutThreadCopyDataPerWrite
>
{};
float
time
=
launch_kernel
(
run_gridwise_convolution
<
decltype
(
gridwise_conv
),
T
>
,
dim3
(
GridSize
),
...
...
src/include/ConstantTensorDescriptor.hip.hpp
View file @
471830a0
...
...
@@ -381,7 +381,8 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
constexpr
auto
I7
=
Number
<
7
>
{};
constexpr
auto
I8
=
Number
<
8
>
{};
printf
(
"%s dim %u, lengths {%u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u %u}
\n
"
,
printf
(
"%s dim %u, lengths {%u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u "
"%u}
\n
"
,
s
,
desc
.
GetDimension
(),
desc
.
GetLength
(
I0
),
...
...
@@ -416,7 +417,8 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
constexpr
auto
I8
=
Number
<
8
>
{};
constexpr
auto
I9
=
Number
<
9
>
{};
printf
(
"%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u %u %u}
\n
"
,
printf
(
"%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u "
"%u %u %u}
\n
"
,
s
,
desc
.
GetDimension
(),
desc
.
GetLength
(
I0
),
...
...
src/include/blockwise_4d_tensor_op.hip.hpp
View file @
471830a0
...
...
@@ -577,8 +577,8 @@ struct Blockwise4dTensorCopy3
iloop_d3
*
thread_per_d3
*
DataPerRead
);
*
(
reinterpret_cast
<
vector_t
*>
(
&
p_dst
[
dst_offset
+
mDstMyThreadOffset
]))
=
*
(
reinterpret_cast
<
const
vector_t
*>
(
&
p_src
[
src_offset
+
mSrcMyThreadOffset
]));
*
(
reinterpret_cast
<
const
vector_t
*>
(
&
p_src
[
src_offset
+
mSrcMyThreadOffset
]));
}
}
}
...
...
@@ -612,7 +612,8 @@ struct Blockwise4dTensorCopy3
return
DataPerRead
*
nloop_d0
*
nloop_d1
*
nloop_d2
*
nloop_d3
;
}
__device__
void
RunLoadRegisterClipboard
(
const
Float
*
__restrict__
p_src
,
Float
*
__restrict__
p_clipboard
)
const
__device__
void
RunLoadRegisterClipboard
(
const
Float
*
__restrict__
p_src
,
Float
*
__restrict__
p_clipboard
)
const
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
...
...
@@ -670,8 +671,8 @@ struct Blockwise4dTensorCopy3
iloop_d3
*
thread_per_d3
*
DataPerRead
);
*
(
reinterpret_cast
<
vector_t
*>
(
&
p_clipboard
[
dst_offset
]))
=
*
(
reinterpret_cast
<
const
vector_t
*>
(
&
p_src
[
src_offset
+
mSrcMyThreadOffset
]));
*
(
reinterpret_cast
<
const
vector_t
*>
(
&
p_src
[
src_offset
+
mSrcMyThreadOffset
]));
}
}
}
...
...
src/include/gridwise_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hip.hpp
View file @
471830a0
...
...
@@ -43,8 +43,9 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn
Float
*
const
__restrict__
p_out_global
)
const
{
// be careful of this assertion
static_assert
(
NPerThread
<=
NPerBlock
&&
NPerBlock
%
NPerThread
==
0
,
"wrong! should satisfy: NPerThread <= NPerBlock && NPerBlock % NPerThread == 0"
);
static_assert
(
NPerThread
<=
NPerBlock
&&
NPerBlock
%
NPerThread
==
0
,
"wrong! should satisfy: NPerThread <= NPerBlock && NPerBlock % NPerThread == 0"
);
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
...
...
@@ -69,8 +70,9 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn
constexpr
index_t
WiPerBlock
=
WoPerBlock
+
X
-
1
;
// divide block work: [K, Ho, Wo, N]
static_assert
(
N
%
NPerBlock
==
0
&&
K
%
KPerBlock
==
0
&&
C
%
CPerBlock
==
0
&&
Ho
%
HoPerBlock
==
0
&&
Wo
%
WoPerBlock
==
0
,
"wrong! cannot evenly divide work for workgroup "
);
static_assert
(
N
%
NPerBlock
==
0
&&
K
%
KPerBlock
==
0
&&
C
%
CPerBlock
==
0
&&
Ho
%
HoPerBlock
==
0
&&
Wo
%
WoPerBlock
==
0
,
"wrong! cannot evenly divide work for workgroup "
);
constexpr
index_t
KBlockWork
=
(
K
+
KPerBlock
-
1
)
/
KPerBlock
;
constexpr
index_t
HBlockWork
=
(
Ho
+
HoPerBlock
-
1
)
/
HoPerBlock
;
...
...
@@ -101,8 +103,7 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn
mod_conv
::
max
(
index_t
(
4
),
InBlockCopyDataPerRead
,
WeiBlockCopyDataPerRead
);
constexpr
auto
in_chwn_block_desc
=
make_ConstantTensorDescriptor_aligned
(
Sequence
<
CPerBlock
,
HiPerBlock
,
WiPerBlock
,
NPerBlock
>
{},
Number
<
max_align
>
{});
Sequence
<
CPerBlock
,
HiPerBlock
,
WiPerBlock
,
NPerBlock
>
{},
Number
<
max_align
>
{});
constexpr
auto
wei_ek_block_desc
=
make_ConstantTensorDescriptor_aligned
(
Sequence
<
CPerBlock
*
Y
*
X
,
KPerBlock
>
{},
Number
<
max_align
>
{});
...
...
@@ -280,17 +281,8 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn
constexpr
index_t
K2
=
GemmMPerThreadSubC
;
constexpr
index_t
K1
=
KPerBlock
/
KPerThread
;
constexpr
auto
out_10d_global_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
K
/
(
K1
*
K2
),
K1
,
K2
,
Ho
,
Wo
/
(
W1
*
W2
),
W1
,
W2
,
N
/
(
N1
*
N2
),
N1
,
N2
>
{});
constexpr
auto
out_10d_global_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
K
/
(
K1
*
K2
),
K1
,
K2
,
Ho
,
Wo
/
(
W1
*
W2
),
W1
,
W2
,
N
/
(
N1
*
N2
),
N1
,
N2
>
{});
constexpr
auto
out_10d_thread_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
KPerThread
/
K2
,
1
,
K2
,
HoPerThread
,
1
,
W1
,
1
,
1
,
1
,
N2
>
{});
...
...
src/include/gridwise_convolution_implicit_gemm_v1_chwn_cyxk_khwn_lds_double_buffer.hip.hpp
View file @
471830a0
...
...
@@ -43,8 +43,9 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn_lds_double_buffer
Float
*
const
__restrict__
p_out_global
)
const
{
// be careful of this assertion
static_assert
(
NPerThread
<=
NPerBlock
&&
NPerBlock
%
NPerThread
==
0
,
"wrong! should satisfy: NPerThread <= NPerBlock && NPerBlock % NPerThread == 0"
);
static_assert
(
NPerThread
<=
NPerBlock
&&
NPerBlock
%
NPerThread
==
0
,
"wrong! should satisfy: NPerThread <= NPerBlock && NPerBlock % NPerThread == 0"
);
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
...
...
@@ -72,8 +73,9 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn_lds_double_buffer
static_assert
(
C
%
(
2
*
CPerBlock
)
==
0
,
"C cannot be evenly divided"
);
// divide block work: [K, Ho, Wo, N]
static_assert
(
N
%
NPerBlock
==
0
&&
K
%
KPerBlock
==
0
&&
C
%
CPerBlock
==
0
&&
Ho
%
HoPerBlock
==
0
&&
Wo
%
WoPerBlock
==
0
,
"wrong! cannot evenly divide work for workgroup "
);
static_assert
(
N
%
NPerBlock
==
0
&&
K
%
KPerBlock
==
0
&&
C
%
CPerBlock
==
0
&&
Ho
%
HoPerBlock
==
0
&&
Wo
%
WoPerBlock
==
0
,
"wrong! cannot evenly divide work for workgroup "
);
constexpr
index_t
KBlockWork
=
(
K
+
KPerBlock
-
1
)
/
KPerBlock
;
constexpr
index_t
HBlockWork
=
(
Ho
+
HoPerBlock
-
1
)
/
HoPerBlock
;
...
...
@@ -104,8 +106,7 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn_lds_double_buffer
mod_conv
::
max
(
index_t
(
4
),
InBlockCopyDataPerRead
,
WeiBlockCopyDataPerRead
);
constexpr
auto
in_chwn_block_desc
=
make_ConstantTensorDescriptor_aligned
(
Sequence
<
CPerBlock
,
HiPerBlock
,
WiPerBlock
,
NPerBlock
>
{},
Number
<
max_align
>
{});
Sequence
<
CPerBlock
,
HiPerBlock
,
WiPerBlock
,
NPerBlock
>
{},
Number
<
max_align
>
{});
constexpr
auto
wei_ek_block_desc
=
make_ConstantTensorDescriptor_aligned
(
Sequence
<
CPerBlock
*
Y
*
X
,
KPerBlock
>
{},
Number
<
max_align
>
{});
...
...
@@ -250,16 +251,15 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn_lds_double_buffer
blockwise_wei_copy
.
RunLoadRegisterClipboard
(
p_wei_global_block_offset
,
p_wei_register_clipboard
);
// a series of batched GEMM
for
(
index_t
y
=
0
;
y
<
Y
;
++
y
)
{
for
(
index_t
x
=
0
;
x
<
X
;
++
x
)
{
blockwise_batch_gemm
.
Run
(
p_wei_block_now
+
wei_cyxk_block_desc
.
Get1dIndex
(
0
,
y
,
x
,
0
),
p_in_block_now
+
in_chwn_block_desc
.
Get1dIndex
(
0
,
y
,
x
,
0
),
p_out_thread
);
blockwise_batch_gemm
.
Run
(
p_wei_block_now
+
wei_cyxk_block_desc
.
Get1dIndex
(
0
,
y
,
x
,
0
),
p_in_block_now
+
in_chwn_block_desc
.
Get1dIndex
(
0
,
y
,
x
,
0
),
p_out_thread
);
}
}
...
...
@@ -291,10 +291,10 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn_lds_double_buffer
{
for
(
index_t
x
=
0
;
x
<
X
;
++
x
)
{
blockwise_batch_gemm
.
Run
(
p_wei_block_double
+
wei_cyxk_block_desc
.
Get1dIndex
(
0
,
y
,
x
,
0
),
p_in_block_double
+
in_chwn_block_desc
.
Get1dIndex
(
0
,
y
,
x
,
0
),
p_out_thread
);
blockwise_batch_gemm
.
Run
(
p_wei_block_double
+
wei_cyxk_block_desc
.
Get1dIndex
(
0
,
y
,
x
,
0
),
p_in_block_double
+
in_chwn_block_desc
.
Get1dIndex
(
0
,
y
,
x
,
0
),
p_out_thread
);
}
}
...
...
@@ -376,17 +376,8 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn_lds_double_buffer
constexpr
index_t
K2
=
GemmMPerThreadSubC
;
constexpr
index_t
K1
=
KPerBlock
/
KPerThread
;
constexpr
auto
out_10d_global_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
K
/
(
K1
*
K2
),
K1
,
K2
,
Ho
,
Wo
/
(
W1
*
W2
),
W1
,
W2
,
N
/
(
N1
*
N2
),
N1
,
N2
>
{});
constexpr
auto
out_10d_global_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
K
/
(
K1
*
K2
),
K1
,
K2
,
Ho
,
Wo
/
(
W1
*
W2
),
W1
,
W2
,
N
/
(
N1
*
N2
),
N1
,
N2
>
{});
constexpr
auto
out_10d_thread_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
KPerThread
/
K2
,
1
,
K2
,
HoPerThread
,
1
,
W1
,
1
,
1
,
1
,
N2
>
{});
...
...
src/include/gridwise_convolution_implicit_gemm_v1r2_chwn_cyxk_khwn.hip.hpp
View file @
471830a0
...
...
@@ -43,8 +43,9 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
Float
*
const
__restrict__
p_out_global
)
const
{
// be careful of this assertion
static_assert
(
NPerThread
<=
NPerBlock
&&
NPerBlock
%
NPerThread
==
0
,
"wrong! should satisfy: NPerThread <= NPerBlock && NPerBlock % NPerThread == 0"
);
static_assert
(
NPerThread
<=
NPerBlock
&&
NPerBlock
%
NPerThread
==
0
,
"wrong! should satisfy: NPerThread <= NPerBlock && NPerBlock % NPerThread == 0"
);
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
...
...
@@ -69,8 +70,9 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
constexpr
index_t
WiPerBlock
=
WoPerBlock
+
X
-
1
;
// divide block work: [K, Ho, Wo, N]
static_assert
(
N
%
NPerBlock
==
0
&&
K
%
KPerBlock
==
0
&&
C
%
CPerBlock
==
0
&&
Ho
%
HoPerBlock
==
0
&&
Wo
%
WoPerBlock
==
0
,
"wrong! cannot evenly divide work for workgroup "
);
static_assert
(
N
%
NPerBlock
==
0
&&
K
%
KPerBlock
==
0
&&
C
%
CPerBlock
==
0
&&
Ho
%
HoPerBlock
==
0
&&
Wo
%
WoPerBlock
==
0
,
"wrong! cannot evenly divide work for workgroup "
);
constexpr
index_t
KBlockWork
=
(
K
+
KPerBlock
-
1
)
/
KPerBlock
;
constexpr
index_t
HBlockWork
=
(
Ho
+
HoPerBlock
-
1
)
/
HoPerBlock
;
...
...
@@ -93,7 +95,8 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
const
index_t
wi_block_data_begin
=
wo_block_data_begin
;
// 2d tensor view of gridwise weight
constexpr
auto
wei_ck_global_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
C
,
K
>
{},
Sequence
<
Y
*
X
*
K
,
1
>
{});
constexpr
auto
wei_ck_global_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
C
,
K
>
{},
Sequence
<
Y
*
X
*
K
,
1
>
{});
// tensor view of blockwise input and weight in LDS
// be careful of alignment
...
...
@@ -124,7 +127,7 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
// blockwise wei copy
// format is [CPerBlock, KPerBlock]
const
auto
blockwise_wei_copy
=
#if 0//debug
#if 0
//
debug
Blockwise2dTensorCopy1<BlockSize,
Float,
decltype(wei_ck_global_desc),
...
...
@@ -139,14 +142,16 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
WeiBlockCopyDataPerRead
>
{};
#endif
// a series of blockwise batched GEMM
// C_matrix += transpose(A_matrix) * B_matrix
// A_matrix and B_matrix saved in LDS, C_matrix saved in register
// A_matrix[C,K] is a sub-matrix of wei_block[C,K]
// B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N]
// C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N]
constexpr
auto
a_cxk_block_mtx_desc
=
make_ConstantMatrixDescriptor
(
Number
<
CPerBlock
>
{},
Number
<
KPerBlock
>
{},
Number
<
wei_ck_block_desc
.
GetStride
(
I0
)
>
{});
// a series of blockwise batched GEMM
// C_matrix += transpose(A_matrix) * B_matrix
// A_matrix and B_matrix saved in LDS, C_matrix saved in register
// A_matrix[C,K] is a sub-matrix of wei_block[C,K]
// B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N]
// C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N]
constexpr
auto
a_cxk_block_mtx_desc
=
make_ConstantMatrixDescriptor
(
Number
<
CPerBlock
>
{},
Number
<
KPerBlock
>
{},
Number
<
wei_ck_block_desc
.
GetStride
(
I0
)
>
{});
constexpr
auto
b_cxwn_block_mtx_desc
=
make_ConstantMatrixDescriptor
(
Number
<
CPerBlock
>
{},
...
...
@@ -180,7 +185,7 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
GemmDataPerReadB
>
{};
// LDS: be careful of alignment
constexpr
index_t
in_block_space
=
in_chwn_block_desc
.
GetElementSpace
(
Number
<
max_align
>
{});
constexpr
index_t
in_block_space
=
in_chwn_block_desc
.
GetElementSpace
(
Number
<
max_align
>
{});
constexpr
index_t
wei_block_space
=
wei_ck_block_desc
.
GetElementSpace
(
Number
<
max_align
>
{});
__shared__
Float
p_in_block
[
in_block_space
];
...
...
@@ -227,8 +232,8 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
{
// weight: global mem to LDS
blockwise_wei_copy
.
Run
(
p_wei_global_block_offset
+
wei_cyxk_global_desc
.
Get1dIndex
(
0
,
y
,
x
,
0
),
p_wei_block
);
wei_cyxk_global_desc
.
Get1dIndex
(
0
,
y
,
x
,
0
),
p_wei_block
);
__syncthreads
();
...
...
@@ -297,17 +302,8 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
constexpr
index_t
K2
=
GemmMPerThreadSubC
;
constexpr
index_t
K1
=
KPerBlock
/
KPerThread
;
constexpr
auto
out_10d_global_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
K
/
(
K1
*
K2
),
K1
,
K2
,
Ho
,
Wo
/
(
W1
*
W2
),
W1
,
W2
,
N
/
(
N1
*
N2
),
N1
,
N2
>
{});
constexpr
auto
out_10d_global_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
K
/
(
K1
*
K2
),
K1
,
K2
,
Ho
,
Wo
/
(
W1
*
W2
),
W1
,
W2
,
N
/
(
N1
*
N2
),
N1
,
N2
>
{});
constexpr
auto
out_10d_thread_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
KPerThread
/
K2
,
1
,
K2
,
HoPerThread
,
1
,
W1
,
1
,
1
,
1
,
N2
>
{});
...
...
src/include/threadwise_gemm.hip.hpp
View file @
471830a0
#pragma once
template
<
class
Float
,
class
SrcMatrix
,
class
DstMatrix
,
index_t
NRow
,
index_t
NCol
,
index_t
DataPerRead
>
template
<
class
Float
,
class
SrcMatrix
,
class
DstMatrix
,
index_t
NRow
,
index_t
NCol
,
index_t
DataPerRead
>
__device__
void
threadwise_matrix_copy
(
SrcMatrix
,
const
Float
*
__restrict__
p_src
,
DstMatrix
,
...
...
@@ -22,7 +27,7 @@ __device__ void threadwise_matrix_copy(SrcMatrix,
const
index_t
src_index
=
src_mtx
.
Get1dIndex
(
i
,
j
);
const
index_t
dst_index
=
dst_mtx
.
Get1dIndex
(
i
,
j
);
*
reinterpret_cast
<
vector_t
*>
(
&
p_dst
[
dst_index
])
=
*
reinterpret_cast
<
vector_t
*>
(
&
p_dst
[
dst_index
])
=
*
reinterpret_cast
<
const
vector_t
*>
(
&
p_src
[
src_index
]);
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment