gaoqiong / composable_kernel / Commits / cf9bd973

Commit cf9bd973, authored Feb 25, 2020 by Chao Liu

    refactoring blockwise gemm

Parent: 7d09790a

Showing 7 changed files with 173 additions and 174 deletions (+173, -174).
composable_kernel/include/tensor_operation/blockwise_gemm.hpp  (+108, -74)
composable_kernel/include/tensor_operation/gridwise_gemm.hpp  (+8, -12)
composable_kernel/include/tensor_operation/threadwise_gemm.hpp  (+26, -73)
composable_kernel/include/tensor_operation/threadwise_generic_tensor_op.hpp  (+5, -9)
composable_kernel/include/tensor_operation/threadwise_generic_tensor_op_deprecated.hpp  (+20, -0)
driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp  (+1, -1)
driver/src/conv_driver.cpp  (+5, -5)
composable_kernel/include/tensor_operation/blockwise_gemm.hpp

@@ -2,7 +2,9 @@
 #define CK_BLOCKWISE_GEMM_HPP
 
 #include "common_header.hpp"
-#include "ConstantMatrixDescriptor.hpp"
+#include "tensor_descriptor.hpp"
+#include "tensor_descriptor_helper.hpp"
+#include "threadwise_generic_tensor_slice_copy.hpp"
 #include "threadwise_gemm.hpp"
 
 namespace ck {
@@ -38,16 +40,22 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
     __device__ BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2()
     {
+        static_assert(BlockMatrixA::GetNumOfDimension() == 2 &&
+                          BlockMatrixB::GetNumOfDimension() == 2 &&
+                          ThreadMatrixC::GetNumOfDimension() == 2,
+                      "wrong! A, B, C matrix should be 2D tensors");
+
         constexpr index_t ThreadPerLevel1Cluster = MLevel0ThreadCluster * NLevel0ThreadCluster *
                                                    MLevel1ThreadCluster * NLevel1ThreadCluster;
 
         static_assert(BlockSize == ThreadPerLevel1Cluster, "wrong! wrong blocksize\n");
 
-        static_assert(BlockMatrixA::NRow() == BlockMatrixB::NRow(),
+        static_assert(BlockMatrixA::GetLengths()[0] == BlockMatrixB::GetLengths()[0],
                       "wrong! K dimension not consistent\n");
 
-        constexpr index_t M = BlockMatrixA::NCol(); // A is transposed
-        constexpr index_t N = BlockMatrixB::NCol();
+        constexpr index_t M = BlockMatrixA::GetLengths()[1]; // A is transposed
+        constexpr index_t N = BlockMatrixB::GetLengths()[1];
 
         static_assert(M % (MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster) == 0 &&
                           N % (NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster) == 0,
@@ -59,14 +67,14 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
         auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id());
 
-        mMyThreadOffsetA = BlockMatrixA::GetOffsetFromMultiIndex(0, c_thread_mtx_index.row);
-        mMyThreadOffsetB = BlockMatrixB::GetOffsetFromMultiIndex(0, c_thread_mtx_index.col);
+        mMyThreadOffsetA = BlockMatrixA::CalculateOffset({0, c_thread_mtx_index.row});
+        mMyThreadOffsetB = BlockMatrixB::CalculateOffset({0, c_thread_mtx_index.col});
     }
 
     __device__ static constexpr auto GetThreadMatrixCLengths()
     {
-        constexpr index_t M = BlockMatrixA::NCol(); // A is transposed
-        constexpr index_t N = BlockMatrixB::NCol();
+        constexpr index_t M = BlockMatrixA::GetLengths()[1]; // A is transposed
+        constexpr index_t N = BlockMatrixB::GetLengths()[1];
 
         constexpr index_t MRepeat =
             M / (MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster);
@@ -125,8 +133,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
         constexpr index_t K = a_block_mtx.NRow();
 
-        constexpr index_t MPerThread = c_thread_mtx.NRow();
-        constexpr index_t NPerThread = c_thread_mtx.NCol();
+        constexpr index_t MPerThread = c_thread_mtx.GetLengths()[0];
+        constexpr index_t NPerThread = c_thread_mtx.GetLengths()[1];
 
         constexpr index_t MPerLevel1Cluster =
             MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster;
@@ -138,25 +146,36 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
         // thread A, B for GEMM
-        constexpr auto a_thread_mtx =
-            make_ConstantMatrixDescriptor_packed(Number<KPerThreadLoop>{}, Number<MPerThread>{});
+        constexpr auto a_thread_mtx =
+            make_native_tensor_descriptor_packed(Sequence<KPerThreadLoop, MPerThread>{});
 
-        constexpr auto b_thread_mtx =
-            make_ConstantMatrixDescriptor_packed(Number<KPerThreadLoop>{}, Number<NPerThread>{});
+        constexpr auto b_thread_mtx =
+            make_native_tensor_descriptor_packed(Sequence<KPerThreadLoop, NPerThread>{});
 
         FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
         FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
 
-        constexpr auto a_thread_copy = ThreadwiseMatrixSliceCopy<BlockMatrixA,
-                                                                 decltype(a_thread_mtx),
-                                                                 KPerThreadLoop,
-                                                                 MPerThreadSubC,
-                                                                 ThreadGemmADataPerRead_M>{};
-
-        constexpr auto b_thread_copy = ThreadwiseMatrixSliceCopy<BlockMatrixB,
-                                                                 decltype(b_thread_mtx),
-                                                                 KPerThreadLoop,
-                                                                 NPerThreadSubC,
-                                                                 ThreadGemmBDataPerRead_N>{};
+        constexpr auto a_thread_copy =
+            ThreadwiseGenericTensorSliceCopy_v4r2<BlockMatrixA,
+                                                  decltype(a_thread_mtx),
+                                                  Sequence<KPerThreadLoop, MPerThreadSubC>,
+                                                  Sequence<0, 1>,
+                                                  1,
+                                                  ThreadGemmADataPerRead_M,
+                                                  ThreadGemmADataPerRead_M,
+                                                  AddressSpace::Lds,
+                                                  AddressSpace::Vgpr,
+                                                  InMemoryDataOperation::Set>({0, 0}, {0, 0});
+
+        constexpr auto b_thread_copy =
+            ThreadwiseGenericTensorSliceCopy_v4r2<BlockMatrixB,
+                                                  decltype(b_thread_mtx),
+                                                  Sequence<KPerThreadLoop, NPerThreadSubC>,
+                                                  Sequence<0, 1>,
+                                                  1,
+                                                  ThreadGemmBDataPerRead_N,
+                                                  ThreadGemmBDataPerRead_N,
+                                                  AddressSpace::Lds,
+                                                  AddressSpace::Vgpr,
+                                                  InMemoryDataOperation::Set>({0, 0}, {0, 0});
 
         constexpr auto threadwise_gemm =
             ThreadwiseGemmTransANormalBNormalC<decltype(a_thread_mtx),
@@ -171,9 +190,10 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
         for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
         {
-            a_thread_copy.Run(
-                p_a_block + a_block_mtx.CalculateOffset(k_begin, m_repeat * MPerLevel1Cluster) +
-                    mMyThreadOffsetA,
-                p_a_thread + a_thread_mtx.CalculateOffset(0, m_repeat * MPerThreadSubC));
+            a_thread_copy.Run(
+                p_a_block + a_block_mtx.CalculateOffset({k_begin, m_repeat * MPerLevel1Cluster}) +
+                    mMyThreadOffsetA,
+                p_a_thread + a_thread_mtx.CalculateOffset({0, m_repeat * MPerThreadSubC}));
         }
 
 #pragma unroll
@@ -181,9 +201,10 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
         for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
         {
-            b_thread_copy.Run(
-                p_b_block + b_block_mtx.CalculateOffset(k_begin, n_repeat * NPerLevel1Cluster) +
-                    mMyThreadOffsetB,
-                p_b_thread + b_thread_mtx.CalculateOffset(0, n_repeat * NPerThreadSubC));
+            b_thread_copy.Run(
+                p_b_block + b_block_mtx.CalculateOffset({k_begin, n_repeat * NPerLevel1Cluster}) +
+                    mMyThreadOffsetB,
+                p_b_thread + b_thread_mtx.CalculateOffset({0, n_repeat * NPerThreadSubC}));
         }
 
         // C += A * B
@@ -217,34 +238,47 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
         // thread A, B
-        constexpr auto a_thread_mtx =
-            make_ConstantMatrixDescriptor_packed(Number<KPerThreadLoop>{}, Number<MPerThread>{});
+        constexpr auto a_thread_mtx =
+            make_native_tensor_descriptor_packed(Sequence<KPerThreadLoop, MPerThread>{});
 
-        constexpr auto b_thread_mtx =
-            make_ConstantMatrixDescriptor_packed(Number<KPerThreadLoop>{}, Number<NPerThread>{});
+        constexpr auto b_thread_mtx =
+            make_native_tensor_descriptor_packed(Sequence<KPerThreadLoop, NPerThread>{});
 
         // thread A-sub, B-sub
-        constexpr auto a_thread_sub_mtx = a_thread_mtx.MakeSubMatrixDescriptor(
-            Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{});
+        constexpr auto a_thread_sub_mtx = make_native_tensor_descriptor(
+            Sequence<KPerThreadLoop, MPerThreadSubC>{}, Sequence<MPerThread, 1>{});
 
-        constexpr auto b_thread_sub_mtx = b_thread_mtx.MakeSubMatrixDescriptor(
-            Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{});
+        constexpr auto b_thread_sub_mtx = make_native_tensor_descriptor(
+            Sequence<KPerThreadLoop, NPerThreadSubC>{}, Sequence<NPerThread, 1>{});
 
         // thread C-sub
-        constexpr auto c_thread_sub_mtx = ThreadMatrixC::MakeSubMatrixDescriptor(
-            Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{});
+        constexpr auto c_thread_sub_mtx = make_native_tensor_descriptor(
+            Sequence<MPerThreadSubC, NPerThreadSubC>{}, Sequence<NPerThread, 1>{});
 
         FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
         FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
 
-        constexpr auto a_thread_copy = ThreadwiseMatrixSliceCopy<BlockMatrixA,
-                                                                 decltype(a_thread_mtx),
-                                                                 KPerThreadLoop,
-                                                                 MPerThreadSubC,
-                                                                 ThreadGemmADataPerRead_M>{};
-
-        constexpr auto b_thread_copy = ThreadwiseMatrixSliceCopy<BlockMatrixB,
-                                                                 decltype(b_thread_mtx),
-                                                                 KPerThreadLoop,
-                                                                 NPerThreadSubC,
-                                                                 ThreadGemmBDataPerRead_N>{};
+        constexpr auto a_thread_copy =
+            ThreadwiseGenericTensorSliceCopy_v4r2<BlockMatrixA,
+                                                  decltype(a_thread_sub_mtx),
+                                                  decltype(a_thread_sub_mtx.GetLengths()),
+                                                  Sequence<0, 1>,
+                                                  1,
+                                                  ThreadGemmADataPerRead_M,
+                                                  ThreadGemmADataPerRead_M,
+                                                  AddressSpace::Lds,
+                                                  AddressSpace::Vgpr,
+                                                  InMemoryDataOperation::Set>({0, 0}, {0, 0});
+
+        constexpr auto b_thread_copy =
+            ThreadwiseGenericTensorSliceCopy_v4r2<BlockMatrixB,
+                                                  decltype(b_thread_mtx),
+                                                  decltype(b_thread_sub_mtx.GetLengths()),
+                                                  Sequence<0, 1>,
+                                                  1,
+                                                  ThreadGemmBDataPerRead_N,
+                                                  ThreadGemmBDataPerRead_N,
+                                                  AddressSpace::Lds,
+                                                  AddressSpace::Vgpr,
+                                                  InMemoryDataOperation::Set>({0, 0}, {0, 0});
 
         constexpr auto threadwise_gemm =
             ThreadwiseGemmTransANormalBNormalC<decltype(a_thread_sub_mtx),
@@ -261,77 +295,77 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
         b_thread_copy.Run(p_b_block_off, p_b_thread);
 
         // read B_sub_1
-        b_thread_copy.Run(p_b_block_off + b_block_mtx.CalculateOffset(0, NPerLevel1Cluster),
-                          p_b_thread + b_thread_mtx.CalculateOffset(0, NPerThreadSubC));
+        b_thread_copy.Run(p_b_block_off + b_block_mtx.CalculateOffset({0, NPerLevel1Cluster}),
+                          p_b_thread + b_thread_mtx.CalculateOffset({0, NPerThreadSubC}));
 
         // read A_sub_1
-        a_thread_copy.Run(p_a_block_off + a_block_mtx.CalculateOffset(0, MPerLevel1Cluster),
-                          p_a_thread + a_thread_mtx.CalculateOffset(0, MPerThreadSubC));
+        a_thread_copy.Run(p_a_block_off + a_block_mtx.CalculateOffset({0, MPerLevel1Cluster}),
+                          p_a_thread + a_thread_mtx.CalculateOffset({0, MPerThreadSubC}));
 
         // C_sub_00 += transpose(A_sub_0) * B_sub_0
         threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);
 
         // C_sub_01 += transpose(A_sub_0) * B_sub_1
-        threadwise_gemm.Run(p_a_thread,
-                            p_b_thread + b_thread_mtx.CalculateOffset(0, NPerThreadSubC),
-                            p_c_thread + ThreadMatrixC::CalculateOffset(0, NPerThreadSubC));
+        threadwise_gemm.Run(p_a_thread,
+                            p_b_thread + b_thread_mtx.CalculateOffset({0, NPerThreadSubC}),
+                            p_c_thread + ThreadMatrixC::CalculateOffset({0, NPerThreadSubC}));
 
 #pragma unroll
         // loop over rest of k
         for(index_t k = KPerThreadLoop; k < K; k += KPerThreadLoop)
         {
             // read A_sub_0
-            a_thread_copy.Run(p_a_block_off + a_block_mtx.CalculateOffset(k, 0), p_a_thread);
+            a_thread_copy.Run(p_a_block_off + a_block_mtx.CalculateOffset({k, 0}), p_a_thread);
 
             // C_sub_10 += transpose(A_sub_1) * B_sub_0
-            threadwise_gemm.Run(p_a_thread + a_thread_mtx.CalculateOffset(0, MPerThreadSubC),
-                                p_b_thread,
-                                p_c_thread + ThreadMatrixC::CalculateOffset(MPerThreadSubC, 0));
+            threadwise_gemm.Run(p_a_thread + a_thread_mtx.CalculateOffset({0, MPerThreadSubC}),
+                                p_b_thread,
+                                p_c_thread + ThreadMatrixC::CalculateOffset({MPerThreadSubC, 0}));
 
             // read B_sub_0
-            b_thread_copy.Run(p_b_block_off + b_block_mtx.CalculateOffset(k, 0), p_b_thread);
+            b_thread_copy.Run(p_b_block_off + b_block_mtx.CalculateOffset({k, 0}), p_b_thread);
 
             // C_sub_11 += transpose(A_sub_1) * B_sub_1
-            threadwise_gemm.Run(p_a_thread + a_thread_mtx.CalculateOffset(0, MPerThreadSubC),
-                                p_b_thread + b_thread_mtx.CalculateOffset(0, NPerThreadSubC),
-                                p_c_thread +
-                                    ThreadMatrixC::CalculateOffset(MPerThreadSubC, NPerThreadSubC));
+            threadwise_gemm.Run(p_a_thread + a_thread_mtx.CalculateOffset({0, MPerThreadSubC}),
+                                p_b_thread + b_thread_mtx.CalculateOffset({0, NPerThreadSubC}),
+                                p_c_thread +
+                                    ThreadMatrixC::CalculateOffset({MPerThreadSubC, NPerThreadSubC}));
 
             // read B_sub_1
-            b_thread_copy.Run(p_b_block_off + b_block_mtx.CalculateOffset(k, NPerLevel1Cluster),
-                              p_b_thread + b_thread_mtx.CalculateOffset(0, NPerThreadSubC));
+            b_thread_copy.Run(p_b_block_off + b_block_mtx.CalculateOffset({k, NPerLevel1Cluster}),
+                              p_b_thread + b_thread_mtx.CalculateOffset({0, NPerThreadSubC}));
 
             // read A_sub_1
-            a_thread_copy.Run(p_a_block_off + a_block_mtx.CalculateOffset(k, MPerLevel1Cluster),
-                              p_a_thread + a_thread_mtx.CalculateOffset(0, MPerThreadSubC));
+            a_thread_copy.Run(p_a_block_off + a_block_mtx.CalculateOffset({k, MPerLevel1Cluster}),
+                              p_a_thread + a_thread_mtx.CalculateOffset({0, MPerThreadSubC}));
 
             // C_sub_00 += transpose(A_sub_0) * B_sub_0
             threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);
 
             // C_sub_01 += transpose(A_sub_0) * B_sub_1
-            threadwise_gemm.Run(p_a_thread,
-                                p_b_thread + b_thread_mtx.CalculateOffset(0, NPerThreadSubC),
-                                p_c_thread + ThreadMatrixC::CalculateOffset(0, NPerThreadSubC));
+            threadwise_gemm.Run(p_a_thread,
+                                p_b_thread + b_thread_mtx.CalculateOffset({0, NPerThreadSubC}),
+                                p_c_thread + ThreadMatrixC::CalculateOffset({0, NPerThreadSubC}));
         }
 
         // C_sub_10 += transpose(A_sub_1) * B_sub_0
-        threadwise_gemm.Run(p_a_thread + a_thread_mtx.CalculateOffset(0, MPerThreadSubC),
-                            p_b_thread,
-                            p_c_thread + ThreadMatrixC::CalculateOffset(MPerThreadSubC, 0));
+        threadwise_gemm.Run(p_a_thread + a_thread_mtx.CalculateOffset({0, MPerThreadSubC}),
+                            p_b_thread,
+                            p_c_thread + ThreadMatrixC::CalculateOffset({MPerThreadSubC, 0}));
 
         // C_sub_11 += transpose(A_sub_1) * B_sub_1
-        threadwise_gemm.Run(p_a_thread + a_thread_mtx.CalculateOffset(0, MPerThreadSubC),
-                            p_b_thread + b_thread_mtx.CalculateOffset(0, NPerThreadSubC),
-                            p_c_thread +
-                                ThreadMatrixC::CalculateOffset(MPerThreadSubC, NPerThreadSubC));
+        threadwise_gemm.Run(p_a_thread + a_thread_mtx.CalculateOffset({0, MPerThreadSubC}),
+                            p_b_thread + b_thread_mtx.CalculateOffset({0, NPerThreadSubC}),
+                            p_c_thread +
+                                ThreadMatrixC::CalculateOffset({MPerThreadSubC, NPerThreadSubC}));
     }
 
     template <typename FloatA, typename FloatB, typename FloatC>
     __device__ void
     Run(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
     {
 #if CK_EXPERIMENTAL_BLOCKWISE_GEMM_USE_PIPELINE
-        constexpr index_t MPerThread = ThreadMatrixC::NRow();
-        constexpr index_t NPerThread = ThreadMatrixC::NCol();
+        constexpr index_t MPerThread = ThreadMatrixC::GetLengths()[0];
+        constexpr index_t NPerThread = ThreadMatrixC::GetLengths()[1];
 
         constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
         constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
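The recurring pattern in this file is the move from the ConstantMatrixDescriptor interface (NRow()/NCol(), two-argument offset calls such as GetOffsetFromMultiIndex(row, col) or CalculateOffset(row, col)) to the native tensor descriptor interface (GetLengths()[0]/[1] and CalculateOffset({row, col}) taking one multi-index). The standalone sketch below only illustrates that calling-convention change for a packed row-major 2D layout; PackedRowMajor2D and its offset rule are assumptions for illustration, not the CK implementation.

// Illustrative, standalone sketch (not the CK descriptor classes): shows the
// old two-argument offset call next to the new single multi-index call for a
// hypothetical packed row-major 2D descriptor.
#include <array>
#include <cassert>
#include <cstddef>

template <std::size_t NRow, std::size_t NCol>
struct PackedRowMajor2D
{
    // old-style call: separate row/column arguments
    static constexpr std::size_t CalculateOffset(std::size_t row, std::size_t col)
    {
        return row * NCol + col;
    }

    // new-style call: one multi-index, as in CalculateOffset({row, col})
    static constexpr std::size_t CalculateOffset(std::array<std::size_t, 2> idx)
    {
        return idx[0] * NCol + idx[1];
    }
};

int main()
{
    using Desc = PackedRowMajor2D<4, 8>;
    // both spellings address the same element of the packed 4x8 buffer
    assert(Desc::CalculateOffset(2, 5) == Desc::CalculateOffset({2, 5}));
    return 0;
}

For a packed descriptor the two spellings compute the same linear offset; the brace form generalizes naturally to descriptors of any rank, which is presumably why the refactor adopts it everywhere.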
composable_kernel/include/tensor_operation/gridwise_gemm.hpp

@@ -4,9 +4,9 @@
 #include "common_header.hpp"
 #include "tensor_descriptor.hpp"
 #include "tensor_descriptor_helper.hpp"
-#include "ConstantMatrixDescriptor.hpp"
 #include "blockwise_generic_tensor_slice_copy.hpp"
 #include "threadwise_generic_tensor_slice_copy.hpp"
+#include "threadwise_generic_tensor_op.hpp"
 #include "blockwise_gemm.hpp"
 
 namespace ck {
@@ -177,28 +177,24 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
         // b_mtx[KPerBlocl, NPerBlock] is in LDS
         // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
         //   register
-        constexpr auto a_k_m_block_mtx_desc = make_ConstantMatrixDescriptor(a_k_m_block_desc);
-        constexpr auto b_k_n_block_mtx_desc = make_ConstantMatrixDescriptor(b_k_n_block_desc);
-
         // sanity check
         static_assert(MPerBlock % (MPerThread * MLevel0Cluster * MLevel1Cluster) == 0 &&
                           NPerBlock % (NPerThread * NLevel0Cluster * NLevel1Cluster) == 0,
                       "wrong!");
 
         constexpr index_t GemmMRepeat = MPerBlock / (MPerThread * MLevel0Cluster * MLevel1Cluster);
         constexpr index_t GemmNRepeat = NPerBlock / (NPerThread * NLevel0Cluster * NLevel1Cluster);
 
         // c_thread_mtx definition: this is a mess
         // TODO:: more elegent way of defining c_thread_mtx
-        constexpr auto c_m0m1_n0n1_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
-            Number<GemmMRepeat * MPerThread>{}, Number<GemmNRepeat * NPerThread>{});
+        constexpr auto c_m0m1_n0n1_thread_desc = make_native_tensor_descriptor_packed(
+            Sequence<GemmMRepeat * MPerThread, GemmNRepeat * NPerThread>{});
 
         const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
             BlockSize,
-            decltype(a_k_m_block_mtx_desc),
-            decltype(b_k_n_block_mtx_desc),
-            decltype(c_m0m1_n0n1_thread_mtx_desc),
+            decltype(a_k_m_block_desc),
+            decltype(b_k_n_block_desc),
+            decltype(c_m0m1_n0n1_thread_desc),
             MPerThread,
             NPerThread,
             MLevel0Cluster,
@@ -220,10 +216,10 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
         Float* p_b_block_double = p_shared_block + 2 * a_block_space;
 
         // register allocation for output
-        AccFloat p_c_thread[c_m0m1_n0n1_thread_mtx_desc.GetElementSpace()];
+        AccFloat p_c_thread[c_m0m1_n0n1_thread_desc.GetElementSpace()];
 
         // zero out threadwise output
-        threadwise_matrix_set_zero(c_m0m1_n0n1_thread_mtx_desc, p_c_thread);
+        threadwise_generic_tensor_set_zero(c_m0m1_n0n1_thread_desc, p_c_thread);
 
         // LDS double buffer: preload data into LDS
         {
composable_kernel/include/tensor_operation/threadwise_gemm.hpp

@@ -2,59 +2,11 @@
 #define CK_THREADWISE_GEMM_HPP
 
 #include "common_header.hpp"
-#include "ConstantMatrixDescriptor.hpp"
+#include "tensor_descriptor.hpp"
 #include "math.hpp"
 
 namespace ck {
 
-template <typename Float, class Matrix>
-__device__ void threadwise_matrix_set_zero(Matrix, Float* __restrict__ p_thread)
-{
-    for(index_t i = 0; i < Matrix::NRow(); ++i)
-    {
-        for(index_t j = 0; j < Matrix::NCol(); ++j)
-        {
-            const index_t id = Matrix::CalculateOffset(i, j);
-            p_thread[id]     = Float(0);
-        }
-    }
-}
-
-template <typename SrcMatrix,
-          typename DstMatrix,
-          index_t NSliceRow,
-          index_t NSliceCol,
-          index_t DataPerAccess>
-struct ThreadwiseMatrixSliceCopy
-{
-    __device__ constexpr ThreadwiseMatrixSliceCopy()
-    {
-        static_assert(SrcMatrix::RowStride() % DataPerAccess == 0 &&
-                          DstMatrix::RowStride() % DataPerAccess == 0,
-                      "wrong! wrong alignment");
-        static_assert(NSliceCol % DataPerAccess == 0,
-                      "wrong! should be NSliceCol % DataPerAccess == 0");
-    }
-
-    template <typename Data>
-    __device__ static void Run(const Data* p_src, Data* p_dst)
-    {
-        using vector_t = typename vector_type<Data, DataPerAccess>::MemoryType;
-
-        for(index_t i = 0; i < NSliceRow; ++i)
-        {
-            for(index_t j = 0; j < NSliceCol; j += DataPerAccess)
-            {
-                const index_t src_index = SrcMatrix::CalculateOffset(i, j);
-                const index_t dst_index = DstMatrix::CalculateOffset(i, j);
-
-                *reinterpret_cast<vector_t*>(&p_dst[dst_index]) =
-                    *reinterpret_cast<const vector_t*>(&p_src[src_index]);
-            }
-        }
-    }
-};
-
 // C += transpose(A) * B
 //   Element of matrix can be vectorized data
 template <typename MatrixA, typename MatrixB, typename MatrixC>
@@ -62,17 +14,18 @@ struct ThreadwiseGemmTransANormalBNormalC
 {
     __device__ constexpr ThreadwiseGemmTransANormalBNormalC()
     {
-        static_assert(MatrixA::NRow() == MatrixB::NRow() &&
-                          MatrixA::NCol() == MatrixC::NRow() &&
-                          MatrixB::NCol() == MatrixC::NCol(),
+        static_assert(MatrixA::GetLengths()[0] == MatrixB::GetLengths()[0] &&
+                          MatrixA::GetLengths()[1] == MatrixC::GetLengths()[0] &&
+                          MatrixB::GetLengths()[1] == MatrixC::GetLengths()[1],
                       "wrong!");
     }
 
     template <typename FloatA, typename FloatB, typename FloatC>
     __device__ static void Run_source(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
     {
-        constexpr index_t M = MatrixC::NRow();
-        constexpr index_t N = MatrixC::NCol();
-        constexpr index_t K = MatrixA::NRow(); // A is transposed
+        constexpr index_t M = MatrixC::GetLengths()[0];
+        constexpr index_t N = MatrixC::GetLengths()[1];
+        constexpr index_t K = MatrixA::GetLengths()[0]; // A is transposed
 
         for(index_t k = 0; k < K; ++k)
         {
@@ -80,9 +33,9 @@ struct ThreadwiseGemmTransANormalBNormalC
         {
             for(index_t n = 0; n < N; ++n)
             {
-                const index_t aindex = MatrixA::CalculateOffset(k, m); // A is transposed
-                const index_t bindex = MatrixB::CalculateOffset(k, n);
-                const index_t cindex = MatrixC::CalculateOffset(m, n);
+                const index_t aindex = MatrixA::CalculateOffset({k, m}); // A is transposed
+                const index_t bindex = MatrixB::CalculateOffset({k, n});
+                const index_t cindex = MatrixC::CalculateOffset({m, n});
 
                 p_c[cindex] +=
                     inner_product_with_conversion<FloatC>{}(p_a[aindex], p_b[bindex]);
@@ -95,9 +48,9 @@ struct ThreadwiseGemmTransANormalBNormalC
     template <typename FloatA, typename FloatB, typename FloatC>
     __device__ static void Run_amd_asm(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
     {
-        constexpr index_t M = MatrixC::NRow();
-        constexpr index_t N = MatrixC::NCol();
-        constexpr index_t K = MatrixA::NRow(); // A is transposed
+        constexpr index_t M = MatrixC::GetLengths()[0];
+        constexpr index_t N = MatrixC::GetLengths()[1];
+        constexpr index_t K = MatrixA::GetLengths()[0]; // A is transposed
 
         static_assert(N == 4 || N == 2, "wrong! this config not supported by asm yet");
@@ -108,26 +61,26 @@ struct ThreadwiseGemmTransANormalBNormalC
             const index_t aindex = MatrixA::CalculateOffset(k, m); // A is transposed
 
             static_if<N == 2>{}([&](auto) {
-                const index_t bindex_0 = MatrixB::CalculateOffset(k, 0);
-                const index_t bindex_1 = MatrixB::CalculateOffset(k, 1);
+                const index_t bindex_0 = MatrixB::CalculateOffset({k, 0});
+                const index_t bindex_1 = MatrixB::CalculateOffset({k, 1});
 
-                const index_t cindex_0 = MatrixC::CalculateOffset(m, 0);
-                const index_t cindex_1 = MatrixC::CalculateOffset(m, 1);
+                const index_t cindex_0 = MatrixC::CalculateOffset({m, 0});
+                const index_t cindex_1 = MatrixC::CalculateOffset({m, 1});
 
                 amd_assembly_outer_product_1x2(
                     p_a[aindex], p_b[bindex_0], p_b[bindex_1], p_c[cindex_0], p_c[cindex_1]);
             });
 
             static_if<N == 4>{}([&](auto) {
-                const index_t bindex_0 = MatrixB::CalculateOffset(k, 0);
-                const index_t bindex_1 = MatrixB::CalculateOffset(k, 1);
-                const index_t bindex_2 = MatrixB::CalculateOffset(k, 2);
-                const index_t bindex_3 = MatrixB::CalculateOffset(k, 3);
+                const index_t bindex_0 = MatrixB::CalculateOffset({k, 0});
+                const index_t bindex_1 = MatrixB::CalculateOffset({k, 1});
+                const index_t bindex_2 = MatrixB::CalculateOffset({k, 2});
+                const index_t bindex_3 = MatrixB::CalculateOffset({k, 3});
 
-                const index_t cindex_0 = MatrixC::CalculateOffset(m, 0);
-                const index_t cindex_1 = MatrixC::CalculateOffset(m, 1);
-                const index_t cindex_2 = MatrixC::CalculateOffset(m, 2);
-                const index_t cindex_3 = MatrixC::CalculateOffset(m, 3);
+                const index_t cindex_0 = MatrixC::CalculateOffset({m, 0});
+                const index_t cindex_1 = MatrixC::CalculateOffset({m, 1});
+                const index_t cindex_2 = MatrixC::CalculateOffset({m, 2});
+                const index_t cindex_3 = MatrixC::CalculateOffset({m, 3});
 
                 amd_assembly_outer_product_1x4(p_a[aindex],
                                                p_b[bindex_0],
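ThreadwiseGemmTransANormalBNormalC::Run_source, visible in the hunks above, is a plain triple loop computing C += transpose(A) * B, with A stored K x M, B stored K x N and C stored M x N. The host-side sketch below restates that loop with ordinary pointers and row-major indexing; the function name and the raw float interface are illustrative assumptions, not the CK API (which goes through descriptors and inner_product_with_conversion).

// Minimal reference sketch of the C += transpose(A) * B loop, assuming packed
// row-major storage: a is K x M, b is K x N, c is M x N.
#include <cstddef>

void gemm_transA_reference(const float* a, // K x M, row-major
                           const float* b, // K x N, row-major
                           float* c,       // M x N, row-major
                           std::size_t M, std::size_t N, std::size_t K)
{
    for(std::size_t k = 0; k < K; ++k)
        for(std::size_t m = 0; m < M; ++m)
            for(std::size_t n = 0; n < N; ++n)
                c[m * N + n] += a[k * M + m] * b[k * N + n]; // aindex = (k, m): A is transposed
}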
composable_kernel/include/tensor_operation/threadwise_generic_tensor_op.hpp

@@ -2,18 +2,14 @@
 #define CK_THREADWISE_GENERIC_TENSOR_OP_HPP
 
 #include "common_header.hpp"
-#include "ConstantTensorDescriptor_deprecated.hpp"
-#include "ConstantMergedTensorDescriptor_deprecated.hpp"
+#include "tensor_descriptor.hpp"
 
 namespace ck {
 
-template <class Float, class TDesc>
-__device__ void threadwise_generic_tensor_set_zero(TDesc, Float* __restrict__ p)
+template <class Float, class TensorDesc>
+__device__ void threadwise_generic_tensor_set_zero(TensorDesc, Float* __restrict__ p)
 {
-    static_ford<decltype(TDesc::GetLengths())>{}([&](auto multi_id) {
-        constexpr index_t offset = TDesc::GetOffsetFromMultiIndex(multi_id);
-        p[offset]                = static_cast<Float>(0);
-    });
+    ford<decltype(TensorDesc::GetLengths())>{}(
+        [&](auto idx) { p[TensorDesc::CalculateOffset(idx)] = static_cast<Float>(0); });
 }
 
 } // namespace ck
composable_kernel/include/tensor_operation/threadwise_generic_tensor_op_deprecated.hpp  (new file, mode 100644)

+#ifndef CK_THREADWISE_GENERIC_TENSOR_OP_DEPRECATED_HPP
+#define CK_THREADWISE_GENERIC_TENSOR_OP_DEPRECATED_HPP
+
+#include "common_header.hpp"
+#include "ConstantTensorDescriptor_deprecated.hpp"
+#include "ConstantMergedTensorDescriptor_deprecated.hpp"
+
+namespace ck {
+
+template <class Float, class TDesc>
+__device__ void threadwise_generic_tensor_set_zero(TDesc, Float* __restrict__ p)
+{
+    static_ford<decltype(TDesc::GetLengths())>{}([&](auto multi_id) {
+        constexpr index_t offset = TDesc::GetOffsetFromMultiIndex(multi_id);
+        p[offset]                = static_cast<Float>(0);
+    });
+}
+
+} // namespace ck
+#endif
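Both versions of threadwise_generic_tensor_set_zero do the same thing: visit every multi-index of the descriptor's lengths and store zero at the corresponding offset (the deprecated copy keeps the static_ford/GetOffsetFromMultiIndex spelling, the refactored one uses ford/CalculateOffset). A standalone 2D analogue, with an assumed row-major offset rule and without the CK ford helpers, looks like this:

// Standalone analogue (not the CK ford/static_ford utilities): zero every
// element of a packed 2D region by walking all multi-indices of its lengths.
#include <cstddef>

template <typename Float>
void set_zero_2d(Float* p, std::size_t len0, std::size_t len1)
{
    for(std::size_t i0 = 0; i0 < len0; ++i0)
        for(std::size_t i1 = 0; i1 < len1; ++i1)
            p[i0 * len1 + i1] = static_cast<Float>(0); // offset = CalculateOffset({i0, i1})
}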
driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp

@@ -53,7 +53,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
     wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
     out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
 
-#if 0
+#if 1
     // BlockSize = 256, GemmKPerBlock = 8
     constexpr index_t BlockSize = 256;
driver/src/conv_driver.cpp

@@ -18,18 +18,18 @@
 //#include "device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw.hpp"
 //#include "device_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hpp"
 //#include "device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp"
-#include "device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_deprecated.hpp"
-#include "device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"
+//#include "device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_deprecated.hpp"
+//#include "device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"
 //#include "device_convolution_implicit_gemm_v4r2_nchw_kcyx_nkhw.hpp"
 //#include "device_convolution_implicit_gemm_v4r3_nchw_kcyx_nkhw.hpp"
-#include "device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_deprecated.hpp"
+//#include "device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_deprecated.hpp"
 #include "device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
 
 int main(int argc, char* argv[])
 {
     using namespace ck;
 
-#if 1
+#if 0
     // 1x1
     constexpr index_t N = 64;
     constexpr index_t C = 64;
@@ -59,7 +59,7 @@ int main(int argc, char* argv[])
     using LeftPads  = Sequence<0, 3>;
     using RightPads = Sequence<0, 3>;
 
-#elif 0
+#elif 1
     // 3x3, 34x34
     constexpr index_t N = 64;
     constexpr index_t C = 256;