Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
02d23347
"examples/dreambooth/train_dreambooth_lora.py" did not exist on "8874027efc2619df2e8b2b567396573b5e9a2ee6"
Commit
02d23347
authored
May 31, 2021
by
Chao Liu
Browse files
overhauling fwd-v4r4
parent
318db82b
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
130 additions
and
139 deletions
+130
-139
composable_kernel/include/driver/driver_dynamic_gemm_v1r2.hpp
...osable_kernel/include/driver/driver_dynamic_gemm_v1r2.hpp
+128
-112
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw.hpp
...nvolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw.hpp
+2
-27
No files found.
composable_kernel/include/driver/driver_dynamic_gemm_v1r2.hpp
View file @
02d23347
...
@@ -13,10 +13,9 @@ template <index_t BlockSize,
...
@@ -13,10 +13,9 @@ template <index_t BlockSize,
typename
FloatAcc
,
typename
FloatAcc
,
typename
FloatC
,
typename
FloatC
,
InMemoryDataOperation
CGlobalMemoryDataOperation
,
InMemoryDataOperation
CGlobalMemoryDataOperation
,
typename
AGlobalDesc
,
typename
AKMGridDesc
,
typename
BGlobalDesc
,
typename
BKNGridDesc
,
typename
CGlobalDesc
,
typename
CMNGridDesc
,
typename
CBlockClusterDesc
,
index_t
MPerBlock
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
KPerBlock
,
...
@@ -46,23 +45,22 @@ template <index_t BlockSize,
...
@@ -46,23 +45,22 @@ template <index_t BlockSize,
typename
CThreadTransferSrcDstAccessOrder
,
typename
CThreadTransferSrcDstAccessOrder
,
index_t
CThreadTransferSrcDstVectorDim
,
index_t
CThreadTransferSrcDstVectorDim
,
index_t
CThreadTransferDstScalarPerVector
,
index_t
CThreadTransferDstScalarPerVector
,
typename
AGlobalIteratorHacks
,
typename
AGridIteratorHacks
,
typename
BGlobalIteratorHacks
,
typename
BGridIteratorHacks
,
typename
CGlobalIteratorHacks
,
typename
CGridIteratorHacks
,
typename
AGlobalMoveSliceWindowIteratorHacks
,
typename
AGridMoveSliceWindowIteratorHacks
,
typename
BGlobalMoveSliceWindowIteratorHacks
>
typename
BGridMoveSliceWindowIteratorHacks
>
__host__
float
launch_kernel_dynamic_gemm_v1r2
(
const
FloatAB
*
p_a_global
,
__host__
float
launch_kernel_dynamic_gemm_v1r2
(
const
FloatAB
*
p_a_grid
,
const
FloatAB
*
p_b_global
,
const
FloatAB
*
p_b_grid
,
FloatC
*
p_c_global
,
FloatC
*
p_c_grid
,
const
AGlobalDesc
&
a_k_m_global_desc
,
const
AKMGridDesc
&
a_k_m_grid_desc
,
const
BGlobalDesc
&
b_k_n_global_desc
,
const
BKNGridDesc
&
b_k_n_grid_desc
,
const
CGlobalDesc
&
c_m0_m1_n0_n1_global_desc
,
const
CMNGridDesc
&
c_m_n_grid_desc
,
const
CBlockClusterDesc
&
c_block_cluster_desc
,
AGridIteratorHacks
,
AGlobalIteratorHacks
,
BGridIteratorHacks
,
BGlobalIteratorHacks
,
CGridIteratorHacks
,
CGlobalIteratorHacks
,
AGridMoveSliceWindowIteratorHacks
,
AGlobalMoveSliceWindowIteratorHacks
,
BGridMoveSliceWindowIteratorHacks
,
BGlobalMoveSliceWindowIteratorHacks
,
index_t
nrepeat
)
index_t
nrepeat
)
{
{
...
@@ -71,23 +69,41 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
...
@@ -71,23 +69,41 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
const
auto
M
=
a_k_m_g
lobal
_desc
.
GetLength
(
I1
);
const
auto
M
=
a_k_m_g
rid
_desc
.
GetLength
(
I1
);
const
auto
N
=
b_k_n_g
lobal
_desc
.
GetLength
(
I1
);
const
auto
N
=
b_k_n_g
rid
_desc
.
GetLength
(
I1
);
const
auto
K
=
a_k_m_g
lobal
_desc
.
GetLength
(
I0
);
const
auto
K
=
a_k_m_g
rid
_desc
.
GetLength
(
I0
);
if
(
!
(
M
%
MPerBlock
==
0
&&
N
%
NPerBlock
==
0
&&
K
%
KPerBlock
==
0
))
if
(
!
(
M
%
MPerBlock
==
0
&&
N
%
NPerBlock
==
0
&&
K
%
KPerBlock
==
0
))
{
{
throw
std
::
runtime_error
(
"wrong! GEMM size no divisible"
);
throw
std
::
runtime_error
(
"wrong! GEMM size no divisible"
);
}
}
const
expr
auto
M1
=
Number
<
M1PerThread
*
M1N1ThreadClusterM11
*
M1N1ThreadClusterM10
>
{};
const
auto
M1
=
Number
<
M1PerThread
*
M1N1ThreadClusterM11
*
M1N1ThreadClusterM10
>
{};
const
expr
auto
N1
=
Number
<
N1PerThread
*
M1N1ThreadClusterN11
*
M1N1ThreadClusterN10
>
{};
const
auto
N1
=
Number
<
N1PerThread
*
M1N1ThreadClusterN11
*
M1N1ThreadClusterN10
>
{};
if
(
!
(
MPerBlock
%
M1
==
0
&&
NPerBlock
%
N1
==
0
))
if
(
!
(
MPerBlock
%
M1
==
0
&&
NPerBlock
%
N1
==
0
))
{
{
throw
std
::
runtime_error
(
"wrong! GEMM size no divisible"
);
throw
std
::
runtime_error
(
"wrong! GEMM size no divisible"
);
}
}
const
auto
M0
=
M
/
M1
;
const
auto
N0
=
N
/
N1
;
const
auto
c_m0_m1_n0_n1_grid_desc
=
transform_dynamic_tensor_descriptor
(
c_m_n_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
M0
,
M1
)),
make_unmerge_transform
(
make_tuple
(
N0
,
N1
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
,
3
>
{}));
using
CM0M1N0N1GridDesc
=
decltype
(
c_m0_m1_n0_n1_grid_desc
);
// out_gemm_block_cluster_desc
const
auto
c_block_cluster_desc
=
make_cluster_descriptor_v2
(
make_tuple
(
M
/
Number
<
MPerBlock
>
{},
N
/
Number
<
NPerBlock
>
{}));
using
CBlockClusterDesc
=
decltype
(
c_block_cluster_desc
);
// GEMM
// GEMM
using
gridwise_gemm
=
using
gridwise_gemm
=
GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
<
BlockSize
,
GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
<
BlockSize
,
...
@@ -95,9 +111,9 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
...
@@ -95,9 +111,9 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
FloatAcc
,
FloatAcc
,
FloatC
,
FloatC
,
CGlobalMemoryDataOperation
,
CGlobalMemoryDataOperation
,
A
Global
Desc
,
A
KMGrid
Desc
,
B
Global
Desc
,
B
KNGrid
Desc
,
C
Global
Desc
,
C
M0M1N0N1Grid
Desc
,
CBlockClusterDesc
,
CBlockClusterDesc
,
MPerBlock
,
MPerBlock
,
NPerBlock
,
NPerBlock
,
...
@@ -128,11 +144,11 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
...
@@ -128,11 +144,11 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
CThreadTransferSrcDstAccessOrder
,
CThreadTransferSrcDstAccessOrder
,
CThreadTransferSrcDstVectorDim
,
CThreadTransferSrcDstVectorDim
,
CThreadTransferDstScalarPerVector
,
CThreadTransferDstScalarPerVector
,
AG
lobal
IteratorHacks
,
AG
rid
IteratorHacks
,
BG
lobal
IteratorHacks
,
BG
rid
IteratorHacks
,
CG
lobal
IteratorHacks
,
CG
rid
IteratorHacks
,
AG
lobal
MoveSliceWindowIteratorHacks
,
AG
rid
MoveSliceWindowIteratorHacks
,
BG
lobal
MoveSliceWindowIteratorHacks
>
;
BG
rid
MoveSliceWindowIteratorHacks
>
;
const
auto
GridSize
=
(
M
/
MPerBlock
)
*
(
N
/
NPerBlock
);
const
auto
GridSize
=
(
M
/
MPerBlock
)
*
(
N
/
NPerBlock
);
...
@@ -149,9 +165,9 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
...
@@ -149,9 +165,9 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
FloatAB
,
FloatAB
,
FloatAB
,
FloatAB
,
FloatC
,
FloatC
,
remove_reference_t
<
A
Global
Desc
>
,
remove_reference_t
<
A
KMGrid
Desc
>
,
remove_reference_t
<
B
Global
Desc
>
,
remove_reference_t
<
B
KNGrid
Desc
>
,
remove_reference_t
<
C
Global
Desc
>
,
remove_reference_t
<
C
M0M1N0N1Grid
Desc
>
,
remove_reference_t
<
CBlockClusterDesc
>
,
remove_reference_t
<
CBlockClusterDesc
>
,
true
,
true
,
true
>
;
true
>
;
...
@@ -162,12 +178,12 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
...
@@ -162,12 +178,12 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
0
,
0
,
p_a_g
lobal
,
p_a_g
rid
,
p_b_g
lobal
,
p_b_g
rid
,
p_c_g
lobal
,
p_c_g
rid
,
a_k_m_g
lobal
_desc
,
a_k_m_g
rid
_desc
,
b_k_n_g
lobal
_desc
,
b_k_n_g
rid
_desc
,
c_m0_m1_n0_n1_g
lobal
_desc
,
c_m0_m1_n0_n1_g
rid
_desc
,
c_block_cluster_desc
);
c_block_cluster_desc
);
}
}
else
if
(
has_main_k_block_loop
&&
!
has_double_tail_k_block_loop
)
else
if
(
has_main_k_block_loop
&&
!
has_double_tail_k_block_loop
)
...
@@ -176,9 +192,9 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
...
@@ -176,9 +192,9 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
FloatAB
,
FloatAB
,
FloatAB
,
FloatAB
,
FloatC
,
FloatC
,
remove_reference_t
<
A
Global
Desc
>
,
remove_reference_t
<
A
KMGrid
Desc
>
,
remove_reference_t
<
B
Global
Desc
>
,
remove_reference_t
<
B
KNGrid
Desc
>
,
remove_reference_t
<
C
Global
Desc
>
,
remove_reference_t
<
C
M0M1N0N1Grid
Desc
>
,
remove_reference_t
<
CBlockClusterDesc
>
,
remove_reference_t
<
CBlockClusterDesc
>
,
true
,
true
,
false
>
;
false
>
;
...
@@ -189,12 +205,12 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
...
@@ -189,12 +205,12 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
0
,
0
,
p_a_g
lobal
,
p_a_g
rid
,
p_b_g
lobal
,
p_b_g
rid
,
p_c_g
lobal
,
p_c_g
rid
,
a_k_m_g
lobal
_desc
,
a_k_m_g
rid
_desc
,
b_k_n_g
lobal
_desc
,
b_k_n_g
rid
_desc
,
c_m0_m1_n0_n1_g
lobal
_desc
,
c_m0_m1_n0_n1_g
rid
_desc
,
c_block_cluster_desc
);
c_block_cluster_desc
);
}
}
else
if
(
!
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
else
if
(
!
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
...
@@ -203,9 +219,9 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
...
@@ -203,9 +219,9 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
FloatAB
,
FloatAB
,
FloatAB
,
FloatAB
,
FloatC
,
FloatC
,
remove_reference_t
<
A
Global
Desc
>
,
remove_reference_t
<
A
KMGrid
Desc
>
,
remove_reference_t
<
B
Global
Desc
>
,
remove_reference_t
<
B
KNGrid
Desc
>
,
remove_reference_t
<
C
Global
Desc
>
,
remove_reference_t
<
C
M0M1N0N1Grid
Desc
>
,
remove_reference_t
<
CBlockClusterDesc
>
,
remove_reference_t
<
CBlockClusterDesc
>
,
false
,
false
,
true
>
;
true
>
;
...
@@ -216,12 +232,12 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
...
@@ -216,12 +232,12 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
0
,
0
,
p_a_g
lobal
,
p_a_g
rid
,
p_b_g
lobal
,
p_b_g
rid
,
p_c_g
lobal
,
p_c_g
rid
,
a_k_m_g
lobal
_desc
,
a_k_m_g
rid
_desc
,
b_k_n_g
lobal
_desc
,
b_k_n_g
rid
_desc
,
c_m0_m1_n0_n1_g
lobal
_desc
,
c_m0_m1_n0_n1_g
rid
_desc
,
c_block_cluster_desc
);
c_block_cluster_desc
);
}
}
else
else
...
@@ -230,9 +246,9 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
...
@@ -230,9 +246,9 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
FloatAB
,
FloatAB
,
FloatAB
,
FloatAB
,
FloatC
,
FloatC
,
remove_reference_t
<
A
Global
Desc
>
,
remove_reference_t
<
A
KMGrid
Desc
>
,
remove_reference_t
<
B
Global
Desc
>
,
remove_reference_t
<
B
KNGrid
Desc
>
,
remove_reference_t
<
C
Global
Desc
>
,
remove_reference_t
<
C
M0M1N0N1Grid
Desc
>
,
remove_reference_t
<
CBlockClusterDesc
>
,
remove_reference_t
<
CBlockClusterDesc
>
,
false
,
false
,
false
>
;
false
>
;
...
@@ -243,25 +259,25 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
...
@@ -243,25 +259,25 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
0
,
0
,
p_a_g
lobal
,
p_a_g
rid
,
p_b_g
lobal
,
p_b_g
rid
,
p_c_g
lobal
,
p_c_g
rid
,
a_k_m_g
lobal
_desc
,
a_k_m_g
rid
_desc
,
b_k_n_g
lobal
_desc
,
b_k_n_g
rid
_desc
,
c_m0_m1_n0_n1_g
lobal
_desc
,
c_m0_m1_n0_n1_g
rid
_desc
,
c_block_cluster_desc
);
c_block_cluster_desc
);
}
}
return
ave_time
;
return
ave_time
;
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
DeviceMem
a_k_m_g
lobal
_desc_device_buf
(
sizeof
(
A
Global
Desc
));
DeviceMem
a_k_m_g
rid
_desc_device_buf
(
sizeof
(
A
KMGrid
Desc
));
DeviceMem
b_k_n_g
lobal
_desc_device_buf
(
sizeof
(
B
Global
Desc
));
DeviceMem
b_k_n_g
rid
_desc_device_buf
(
sizeof
(
B
KNGrid
Desc
));
DeviceMem
c_m0_m1_n0_n1_g
lobal
_desc_device_buf
(
sizeof
(
C
Global
Desc
));
DeviceMem
c_m0_m1_n0_n1_g
rid
_desc_device_buf
(
sizeof
(
C
M0M1N0N1Grid
Desc
));
DeviceMem
c_block_cluster_desc_device_buf
(
sizeof
(
c_block_cluster_desc
));
DeviceMem
c_block_cluster_desc_device_buf
(
sizeof
(
c_block_cluster_desc
));
a_k_m_g
lobal
_desc_device_buf
.
ToDevice
(
&
a_k_m_g
lobal
_desc
);
a_k_m_g
rid
_desc_device_buf
.
ToDevice
(
&
a_k_m_g
rid
_desc
);
b_k_n_g
lobal
_desc_device_buf
.
ToDevice
(
&
b_k_n_g
lobal
_desc
);
b_k_n_g
rid
_desc_device_buf
.
ToDevice
(
&
b_k_n_g
rid
_desc
);
c_m0_m1_n0_n1_g
lobal
_desc_device_buf
.
ToDevice
(
&
c_m0_m1_n0_n1_g
lobal
_desc
);
c_m0_m1_n0_n1_g
rid
_desc_device_buf
.
ToDevice
(
&
c_m0_m1_n0_n1_g
rid
_desc
);
c_block_cluster_desc_device_buf
.
ToDevice
(
&
c_block_cluster_desc
);
c_block_cluster_desc_device_buf
.
ToDevice
(
&
c_block_cluster_desc
);
float
ave_time
=
0
;
float
ave_time
=
0
;
...
@@ -272,9 +288,9 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
...
@@ -272,9 +288,9 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
FloatAB
,
FloatAB
,
FloatAB
,
FloatAB
,
FloatC
,
FloatC
,
remove_reference_t
<
A
Global
Desc
>
,
remove_reference_t
<
A
KMGrid
Desc
>
,
remove_reference_t
<
B
Global
Desc
>
,
remove_reference_t
<
B
KNGrid
Desc
>
,
remove_reference_t
<
C
Global
Desc
>
,
remove_reference_t
<
C
M0M1N0N1Grid
Desc
>
,
remove_reference_t
<
CBlockClusterDesc
>
,
remove_reference_t
<
CBlockClusterDesc
>
,
true
,
true
,
true
>
;
true
>
;
...
@@ -286,12 +302,12 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
...
@@ -286,12 +302,12 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
0
,
0
,
p_a_g
lobal
,
p_a_g
rid
,
p_b_g
lobal
,
p_b_g
rid
,
p_c_g
lobal
,
p_c_g
rid
,
(
void
__CONSTANT__
*
)
a_k_m_g
lobal
_desc_device_buf
.
GetDeviceBuffer
(),
(
void
__CONSTANT__
*
)
a_k_m_g
rid
_desc_device_buf
.
GetDeviceBuffer
(),
(
void
__CONSTANT__
*
)
b_k_n_g
lobal
_desc_device_buf
.
GetDeviceBuffer
(),
(
void
__CONSTANT__
*
)
b_k_n_g
rid
_desc_device_buf
.
GetDeviceBuffer
(),
(
void
__CONSTANT__
*
)
c_m0_m1_n0_n1_g
lobal
_desc_device_buf
.
GetDeviceBuffer
(),
(
void
__CONSTANT__
*
)
c_m0_m1_n0_n1_g
rid
_desc_device_buf
.
GetDeviceBuffer
(),
(
void
__CONSTANT__
*
)
c_block_cluster_desc_device_buf
.
GetDeviceBuffer
());
(
void
__CONSTANT__
*
)
c_block_cluster_desc_device_buf
.
GetDeviceBuffer
());
}
}
else
if
(
has_main_k_block_loop
&&
!
has_double_tail_k_block_loop
)
else
if
(
has_main_k_block_loop
&&
!
has_double_tail_k_block_loop
)
...
@@ -300,9 +316,9 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
...
@@ -300,9 +316,9 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
FloatAB
,
FloatAB
,
FloatAB
,
FloatAB
,
FloatC
,
FloatC
,
remove_reference_t
<
A
Global
Desc
>
,
remove_reference_t
<
A
KMGrid
Desc
>
,
remove_reference_t
<
B
Global
Desc
>
,
remove_reference_t
<
B
KNGrid
Desc
>
,
remove_reference_t
<
C
Global
Desc
>
,
remove_reference_t
<
C
M0M1N0N1Grid
Desc
>
,
remove_reference_t
<
CBlockClusterDesc
>
,
remove_reference_t
<
CBlockClusterDesc
>
,
true
,
true
,
false
>
;
false
>
;
...
@@ -314,12 +330,12 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
...
@@ -314,12 +330,12 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
0
,
0
,
p_a_g
lobal
,
p_a_g
rid
,
p_b_g
lobal
,
p_b_g
rid
,
p_c_g
lobal
,
p_c_g
rid
,
(
void
__CONSTANT__
*
)
a_k_m_g
lobal
_desc_device_buf
.
GetDeviceBuffer
(),
(
void
__CONSTANT__
*
)
a_k_m_g
rid
_desc_device_buf
.
GetDeviceBuffer
(),
(
void
__CONSTANT__
*
)
b_k_n_g
lobal
_desc_device_buf
.
GetDeviceBuffer
(),
(
void
__CONSTANT__
*
)
b_k_n_g
rid
_desc_device_buf
.
GetDeviceBuffer
(),
(
void
__CONSTANT__
*
)
c_m0_m1_n0_n1_g
lobal
_desc_device_buf
.
GetDeviceBuffer
(),
(
void
__CONSTANT__
*
)
c_m0_m1_n0_n1_g
rid
_desc_device_buf
.
GetDeviceBuffer
(),
(
void
__CONSTANT__
*
)
c_block_cluster_desc_device_buf
.
GetDeviceBuffer
());
(
void
__CONSTANT__
*
)
c_block_cluster_desc_device_buf
.
GetDeviceBuffer
());
}
}
else
if
(
!
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
else
if
(
!
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
...
@@ -328,9 +344,9 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
...
@@ -328,9 +344,9 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
FloatAB
,
FloatAB
,
FloatAB
,
FloatAB
,
FloatC
,
FloatC
,
remove_reference_t
<
A
Global
Desc
>
,
remove_reference_t
<
A
KMGrid
Desc
>
,
remove_reference_t
<
B
Global
Desc
>
,
remove_reference_t
<
B
KNGrid
Desc
>
,
remove_reference_t
<
C
Global
Desc
>
,
remove_reference_t
<
C
M0M1N0N1Grid
Desc
>
,
remove_reference_t
<
CBlockClusterDesc
>
,
remove_reference_t
<
CBlockClusterDesc
>
,
false
,
false
,
true
>
;
true
>
;
...
@@ -342,12 +358,12 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
...
@@ -342,12 +358,12 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
0
,
0
,
p_a_g
lobal
,
p_a_g
rid
,
p_b_g
lobal
,
p_b_g
rid
,
p_c_g
lobal
,
p_c_g
rid
,
(
void
__CONSTANT__
*
)
a_k_m_g
lobal
_desc_device_buf
.
GetDeviceBuffer
(),
(
void
__CONSTANT__
*
)
a_k_m_g
rid
_desc_device_buf
.
GetDeviceBuffer
(),
(
void
__CONSTANT__
*
)
b_k_n_g
lobal
_desc_device_buf
.
GetDeviceBuffer
(),
(
void
__CONSTANT__
*
)
b_k_n_g
rid
_desc_device_buf
.
GetDeviceBuffer
(),
(
void
__CONSTANT__
*
)
c_m0_m1_n0_n1_g
lobal
_desc_device_buf
.
GetDeviceBuffer
(),
(
void
__CONSTANT__
*
)
c_m0_m1_n0_n1_g
rid
_desc_device_buf
.
GetDeviceBuffer
(),
(
void
__CONSTANT__
*
)
c_block_cluster_desc_device_buf
.
GetDeviceBuffer
());
(
void
__CONSTANT__
*
)
c_block_cluster_desc_device_buf
.
GetDeviceBuffer
());
}
}
else
else
...
@@ -356,9 +372,9 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
...
@@ -356,9 +372,9 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
FloatAB
,
FloatAB
,
FloatAB
,
FloatAB
,
FloatC
,
FloatC
,
remove_reference_t
<
A
Global
Desc
>
,
remove_reference_t
<
A
KMGrid
Desc
>
,
remove_reference_t
<
B
Global
Desc
>
,
remove_reference_t
<
B
KNGrid
Desc
>
,
remove_reference_t
<
C
Global
Desc
>
,
remove_reference_t
<
C
M0M1N0N1Grid
Desc
>
,
remove_reference_t
<
CBlockClusterDesc
>
,
remove_reference_t
<
CBlockClusterDesc
>
,
false
,
false
,
false
>
;
false
>
;
...
@@ -370,12 +386,12 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
...
@@ -370,12 +386,12 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_global,
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
0
,
0
,
p_a_g
lobal
,
p_a_g
rid
,
p_b_g
lobal
,
p_b_g
rid
,
p_c_g
lobal
,
p_c_g
rid
,
(
void
__CONSTANT__
*
)
a_k_m_g
lobal
_desc_device_buf
.
GetDeviceBuffer
(),
(
void
__CONSTANT__
*
)
a_k_m_g
rid
_desc_device_buf
.
GetDeviceBuffer
(),
(
void
__CONSTANT__
*
)
b_k_n_g
lobal
_desc_device_buf
.
GetDeviceBuffer
(),
(
void
__CONSTANT__
*
)
b_k_n_g
rid
_desc_device_buf
.
GetDeviceBuffer
(),
(
void
__CONSTANT__
*
)
c_m0_m1_n0_n1_g
lobal
_desc_device_buf
.
GetDeviceBuffer
(),
(
void
__CONSTANT__
*
)
c_m0_m1_n0_n1_g
rid
_desc_device_buf
.
GetDeviceBuffer
(),
(
void
__CONSTANT__
*
)
c_block_cluster_desc_device_buf
.
GetDeviceBuffer
());
(
void
__CONSTANT__
*
)
c_block_cluster_desc_device_buf
.
GetDeviceBuffer
());
}
}
...
...
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw.hpp
View file @
02d23347
...
@@ -482,29 +482,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
...
@@ -482,29 +482,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
const
auto
in_gemmk_gemmn_grid_desc
=
descs
[
I1
];
const
auto
in_gemmk_gemmn_grid_desc
=
descs
[
I1
];
const
auto
out_gemmm_gemmn_grid_desc
=
descs
[
I2
];
const
auto
out_gemmm_gemmn_grid_desc
=
descs
[
I2
];
const
auto
GemmM
=
out_gemmm_gemmn_grid_desc
.
GetLength
(
I0
);
const
auto
GemmN
=
out_gemmm_gemmn_grid_desc
.
GetLength
(
I1
);
const
auto
GemmK
=
wei_gemmk_gemmm_grid_desc
.
GetLength
(
I0
);
constexpr
index_t
GemmM1
=
GemmMPerThread
*
GemmMLevel0Cluster
*
GemmMLevel1Cluster
;
constexpr
index_t
GemmN1
=
GemmNPerThread
*
GemmNLevel0Cluster
*
GemmNLevel1Cluster
;
assert
(
GemmM
%
GemmMPerBlock
==
0
&&
GemmN
%
GemmNPerBlock
==
0
&&
GemmK
%
GemmKPerBlock
==
0
);
const
auto
GemmM0
=
GemmM
/
Number
<
GemmM1
>
{};
const
auto
GemmN0
=
GemmN
/
Number
<
GemmN1
>
{};
const
auto
out_gemmm0_gemmm1_gemmn0_gemmn1_grid_desc
=
transform_dynamic_tensor_descriptor
(
out_gemmm_gemmn_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmM0
,
GemmM1
)),
make_unmerge_transform
(
make_tuple
(
GemmN0
,
GemmN1
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
,
3
>
{}));
// out_gemm_block_cluster_desc
const
auto
out_gemm_block_cluster_desc
=
make_cluster_descriptor_v2
(
make_tuple
(
GemmM
/
Number
<
GemmMPerBlock
>
{},
GemmN
/
Number
<
GemmNPerBlock
>
{}));
// hack to control index calculation when iterating over wei_gemmk_gemmm_grid tensor
// hack to control index calculation when iterating over wei_gemmk_gemmm_grid tensor
constexpr
auto
wei_gemmk_gemmm_grid_iterator_hacks
=
constexpr
auto
wei_gemmk_gemmm_grid_iterator_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
>
{}),
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
>
{}),
...
@@ -543,8 +520,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
...
@@ -543,8 +520,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
InMemoryDataOperation
::
Set
,
InMemoryDataOperation
::
Set
,
decltype
(
wei_gemmk_gemmm_grid_desc
),
decltype
(
wei_gemmk_gemmm_grid_desc
),
decltype
(
in_gemmk_gemmn_grid_desc
),
decltype
(
in_gemmk_gemmn_grid_desc
),
decltype
(
out_gemmm0_gemmm1_gemmn0_gemmn1_grid_desc
),
decltype
(
out_gemmm_gemmn_grid_desc
),
decltype
(
out_gemm_block_cluster_desc
),
GemmMPerBlock
,
GemmMPerBlock
,
GemmNPerBlock
,
GemmNPerBlock
,
GemmKPerBlock
,
GemmKPerBlock
,
...
@@ -587,8 +563,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
...
@@ -587,8 +563,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
static_cast
<
TOut
*>
(
out_n_k_ho_wo_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TOut
*>
(
out_n_k_ho_wo_device_buf
.
GetDeviceBuffer
()),
wei_gemmk_gemmm_grid_desc
,
wei_gemmk_gemmm_grid_desc
,
in_gemmk_gemmn_grid_desc
,
in_gemmk_gemmn_grid_desc
,
out_gemmm0_gemmm1_gemmn0_gemmn1_grid_desc
,
out_gemmm_gemmn_grid_desc
,
out_gemm_block_cluster_desc
,
wei_gemmk_gemmm_grid_iterator_hacks
,
wei_gemmk_gemmm_grid_iterator_hacks
,
in_gemmk_gemmn_grid_iterator_hacks
,
in_gemmk_gemmn_grid_iterator_hacks
,
out_gemmm0_gemmm1_gemmn0_gemmn1_grid_iterator_hacks
,
out_gemmm0_gemmm1_gemmn0_gemmn1_grid_iterator_hacks
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment