Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
aac345ab
Commit
aac345ab
authored
May 12, 2021
by
Chao Liu
Browse files
refactor gridwise gemm
parent
9d8b39a7
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
124 additions
and
127 deletions
+124
-127
composable_kernel/include/driver/driver_dynamic_gemm_v1.hpp
composable_kernel/include/driver/driver_dynamic_gemm_v1.hpp
+65
-73
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp
...kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp
+59
-54
No files found.
composable_kernel/include/driver/driver_dynamic_gemm_v1.hpp
View file @
aac345ab
...
...
@@ -146,16 +146,16 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
if
(
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
{
const
auto
kernel
=
run_gridwise_operation
<
gridwise_gemm
,
const
auto
kernel
=
kernel_dynamic_gemm_v1
<
gridwise_gemm
,
FloatAB
,
FloatAB
,
FloatC
,
remove_reference_t
<
AGlobalDesc
>
,
const
FloatAB
*
,
remove_reference_t
<
BGlobalDesc
>
,
const
FloatAB
*
,
remove_reference_t
<
CGlobalDesc
>
,
FloatC
*
,
remove_reference_t
<
CBlockClusterDesc
>
,
integral_constant
<
bool
,
true
>
,
integral_constant
<
bool
,
true
>
>
;
true
,
true
>
;
ave_time
=
launch_and_time_kernel
(
kernel
,
nrepeat
,
...
...
@@ -163,28 +163,26 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
dim3
(
BlockSize
),
0
,
0
,
a_k_m_global_desc
,
p_a_global
,
b_k_n_global_desc
,
p_b_global
,
c_m0_m1_n0_n1_global_desc
,
p_c_global
,
c_block_cluster_desc
,
integral_constant
<
bool
,
true
>
{},
integral_constant
<
bool
,
true
>
{});
a_k_m_global_desc
,
b_k_n_global_desc
,
c_m0_m1_n0_n1_global_desc
,
c_block_cluster_desc
);
}
else
if
(
has_main_k_block_loop
&&
!
has_double_tail_k_block_loop
)
{
const
auto
kernel
=
run_gridwise_operation
<
gridwise_gemm
,
const
auto
kernel
=
kernel_dynamic_gemm_v1
<
gridwise_gemm
,
FloatAB
,
FloatAB
,
FloatC
,
remove_reference_t
<
AGlobalDesc
>
,
const
FloatAB
*
,
remove_reference_t
<
BGlobalDesc
>
,
const
FloatAB
*
,
remove_reference_t
<
CGlobalDesc
>
,
FloatC
*
,
remove_reference_t
<
CBlockClusterDesc
>
,
integral_constant
<
bool
,
true
>
,
integral_constant
<
bool
,
false
>
>
;
true
,
false
>
;
ave_time
=
launch_and_time_kernel
(
kernel
,
nrepeat
,
...
...
@@ -192,28 +190,26 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
dim3
(
BlockSize
),
0
,
0
,
a_k_m_global_desc
,
p_a_global
,
b_k_n_global_desc
,
p_b_global
,
c_m0_m1_n0_n1_global_desc
,
p_c_global
,
c_block_cluster_desc
,
integral_constant
<
bool
,
true
>
{},
integral_constant
<
bool
,
false
>
{});
a_k_m_global_desc
,
b_k_n_global_desc
,
c_m0_m1_n0_n1_global_desc
,
c_block_cluster_desc
);
}
else
if
(
!
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
{
const
auto
kernel
=
run_gridwise_operation
<
gridwise_gemm
,
const
auto
kernel
=
kernel_dynamic_gemm_v1
<
gridwise_gemm
,
FloatAB
,
FloatAB
,
FloatC
,
remove_reference_t
<
AGlobalDesc
>
,
const
FloatAB
*
,
remove_reference_t
<
BGlobalDesc
>
,
const
FloatAB
*
,
remove_reference_t
<
CGlobalDesc
>
,
FloatC
*
,
remove_reference_t
<
CBlockClusterDesc
>
,
integral_constant
<
bool
,
false
>
,
integral_constant
<
bool
,
true
>
>
;
false
,
true
>
;
ave_time
=
launch_and_time_kernel
(
kernel
,
nrepeat
,
...
...
@@ -221,28 +217,26 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
dim3
(
BlockSize
),
0
,
0
,
a_k_m_global_desc
,
p_a_global
,
b_k_n_global_desc
,
p_b_global
,
c_m0_m1_n0_n1_global_desc
,
p_c_global
,
c_block_cluster_desc
,
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
true
>
{});
a_k_m_global_desc
,
b_k_n_global_desc
,
c_m0_m1_n0_n1_global_desc
,
c_block_cluster_desc
);
}
else
{
const
auto
kernel
=
run_gridwise_operation
<
gridwise_gemm
,
const
auto
kernel
=
kernel_dynamic_gemm_v1
<
gridwise_gemm
,
FloatAB
,
FloatAB
,
FloatC
,
remove_reference_t
<
AGlobalDesc
>
,
const
FloatAB
*
,
remove_reference_t
<
BGlobalDesc
>
,
const
FloatAB
*
,
remove_reference_t
<
CGlobalDesc
>
,
FloatC
*
,
remove_reference_t
<
CBlockClusterDesc
>
,
integral_constant
<
bool
,
false
>
,
integral_constant
<
bool
,
false
>
>
;
false
,
false
>
;
ave_time
=
launch_and_time_kernel
(
kernel
,
nrepeat
,
...
...
@@ -250,15 +244,13 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
dim3
(
BlockSize
),
0
,
0
,
a_k_m_global_desc
,
p_a_global
,
b_k_n_global_desc
,
p_b_global
,
c_m0_m1_n0_n1_global_desc
,
p_c_global
,
c_block_cluster_desc
,
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
false
>
{});
a_k_m_global_desc
,
b_k_n_global_desc
,
c_m0_m1_n0_n1_global_desc
,
c_block_cluster_desc
);
}
return
ave_time
;
...
...
@@ -277,13 +269,13 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
if
(
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
{
const
auto
kernel
=
run_gridwise_operation
<
gridwise_gemm
,
const
auto
kernel
=
kernel_dynamic_gemm_v1
<
gridwise_gemm
,
FloatAB
,
FloatAB
,
FloatC
,
remove_reference_t
<
AGlobalDesc
>
,
const
FloatAB
*
,
remove_reference_t
<
BGlobalDesc
>
,
const
FloatAB
*
,
remove_reference_t
<
CGlobalDesc
>
,
FloatC
*
,
remove_reference_t
<
CBlockClusterDesc
>
,
true
,
true
>
;
...
...
@@ -295,23 +287,23 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
dim3
(
BlockSize
),
0
,
0
,
(
void
__CONSTANT__
*
)
a_k_m_global_desc_device_buf
.
GetDeviceBuffer
(),
p_a_global
,
(
void
__CONSTANT__
*
)
b_k_n_global_desc_device_buf
.
GetDeviceBuffer
(),
p_b_global
,
(
void
__CONSTANT__
*
)
c_m0_m1_n0_n1_global_desc_device_buf
.
GetDeviceBuffer
(),
p_c_global
,
(
void
__CONSTANT__
*
)
a_k_m_global_desc_device_buf
.
GetDeviceBuffer
(),
(
void
__CONSTANT__
*
)
b_k_n_global_desc_device_buf
.
GetDeviceBuffer
(),
(
void
__CONSTANT__
*
)
c_m0_m1_n0_n1_global_desc_device_buf
.
GetDeviceBuffer
(),
(
void
__CONSTANT__
*
)
c_block_cluster_desc_device_buf
.
GetDeviceBuffer
());
}
else
if
(
has_main_k_block_loop
&&
!
has_double_tail_k_block_loop
)
{
const
auto
kernel
=
kernel_dynamic_gemm_v1
<
gridwise_gemm
,
remove_reference_t
<
AGlobalDesc
>
,
FloatAB
,
remove_reference_t
<
BGlobalDesc
>
,
FloatAB
,
remove_reference_t
<
CGlobalDesc
>
,
FloatC
,
remove_reference_t
<
AGlobalDesc
>
,
remove_reference_t
<
BGlobalDesc
>
,
remove_reference_t
<
CGlobalDesc
>
,
remove_reference_t
<
CBlockClusterDesc
>
,
true
,
false
>
;
...
...
@@ -323,23 +315,23 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
dim3
(
BlockSize
),
0
,
0
,
(
void
__CONSTANT__
*
)
a_k_m_global_desc_device_buf
.
GetDeviceBuffer
(),
p_a_global
,
(
void
__CONSTANT__
*
)
b_k_n_global_desc_device_buf
.
GetDeviceBuffer
(),
p_b_global
,
(
void
__CONSTANT__
*
)
c_m0_m1_n0_n1_global_desc_device_buf
.
GetDeviceBuffer
(),
p_c_global
,
(
void
__CONSTANT__
*
)
a_k_m_global_desc_device_buf
.
GetDeviceBuffer
(),
(
void
__CONSTANT__
*
)
b_k_n_global_desc_device_buf
.
GetDeviceBuffer
(),
(
void
__CONSTANT__
*
)
c_m0_m1_n0_n1_global_desc_device_buf
.
GetDeviceBuffer
(),
(
void
__CONSTANT__
*
)
c_block_cluster_desc_device_buf
.
GetDeviceBuffer
());
}
else
if
(
!
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
{
const
auto
kernel
=
kernel_dynamic_gemm_v1
<
gridwise_gemm
,
remove_reference_t
<
AGlobalDesc
>
,
FloatAB
,
remove_reference_t
<
BGlobalDesc
>
,
FloatAB
,
remove_reference_t
<
CGlobalDesc
>
,
FloatC
,
remove_reference_t
<
AGlobalDesc
>
,
remove_reference_t
<
BGlobalDesc
>
,
remove_reference_t
<
CGlobalDesc
>
,
remove_reference_t
<
CBlockClusterDesc
>
,
false
,
true
>
;
...
...
@@ -351,23 +343,23 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
dim3
(
BlockSize
),
0
,
0
,
(
void
__CONSTANT__
*
)
a_k_m_global_desc_device_buf
.
GetDeviceBuffer
(),
p_a_global
,
(
void
__CONSTANT__
*
)
b_k_n_global_desc_device_buf
.
GetDeviceBuffer
(),
p_b_global
,
(
void
__CONSTANT__
*
)
c_m0_m1_n0_n1_global_desc_device_buf
.
GetDeviceBuffer
(),
p_c_global
,
(
void
__CONSTANT__
*
)
a_k_m_global_desc_device_buf
.
GetDeviceBuffer
(),
(
void
__CONSTANT__
*
)
b_k_n_global_desc_device_buf
.
GetDeviceBuffer
(),
(
void
__CONSTANT__
*
)
c_m0_m1_n0_n1_global_desc_device_buf
.
GetDeviceBuffer
(),
(
void
__CONSTANT__
*
)
c_block_cluster_desc_device_buf
.
GetDeviceBuffer
());
}
else
{
const
auto
kernel
=
kernel_dynamic_gemm_v1
<
gridwise_gemm
,
remove_reference_t
<
AGlobalDesc
>
,
FloatAB
,
remove_reference_t
<
BGlobalDesc
>
,
FloatAB
,
remove_reference_t
<
CGlobalDesc
>
,
FloatC
,
remove_reference_t
<
AGlobalDesc
>
,
remove_reference_t
<
BGlobalDesc
>
,
remove_reference_t
<
CGlobalDesc
>
,
remove_reference_t
<
CBlockClusterDesc
>
,
false
,
false
>
;
...
...
@@ -379,12 +371,12 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
dim3
(
BlockSize
),
0
,
0
,
(
void
__CONSTANT__
*
)
a_k_m_global_desc_device_buf
.
GetDeviceBuffer
(),
p_a_global
,
(
void
__CONSTANT__
*
)
b_k_n_global_desc_device_buf
.
GetDeviceBuffer
(),
p_b_global
,
(
void
__CONSTANT__
*
)
c_m0_m1_n0_n1_global_desc_device_buf
.
GetDeviceBuffer
(),
p_c_global
,
(
void
__CONSTANT__
*
)
a_k_m_global_desc_device_buf
.
GetDeviceBuffer
(),
(
void
__CONSTANT__
*
)
b_k_n_global_desc_device_buf
.
GetDeviceBuffer
(),
(
void
__CONSTANT__
*
)
c_m0_m1_n0_n1_global_desc_device_buf
.
GetDeviceBuffer
(),
(
void
__CONSTANT__
*
)
c_block_cluster_desc_device_buf
.
GetDeviceBuffer
());
}
...
...
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp
View file @
aac345ab
...
...
@@ -14,12 +14,12 @@ namespace ck {
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
template
<
typename
GridwiseGemm
,
typename
AGlobalDesc
,
typename
FloatA
,
typename
BGlobalDesc
,
typename
FloatB
,
typename
CGlobalDesc
,
typename
FloatC
,
typename
AGlobalDesc
,
typename
BGlobalDesc
,
typename
CGlobalDesc
,
typename
CBlockClusterDesc
,
bool
HasMainKBlockLoop
,
bool
HasDoubleTailKBlockLoop
>
...
...
@@ -27,20 +27,21 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
kernel_dynamic_gemm_v1
(
const
AGlobalDesc
a_k_m_global_desc
,
const
FloatA
*
__restrict__
p_a_global
,
const
BGlobalDesc
b_k_n_global_desc
,
kernel_dynamic_gemm_v1
(
const
FloatA
*
__restrict__
p_a_global
,
const
FloatB
*
__restrict__
p_b_global
,
const
CGlobalDesc
c_m0_m1_n0_n1_global_desc
,
FloatC
*
__restrict__
p_c_global
,
const
AGlobalDesc
a_k_m_global_desc
,
const
BGlobalDesc
b_k_n_global_desc
,
const
CGlobalDesc
c_m0_m1_n0_n1_global_desc
,
const
CBlockClusterDesc
c_block_cluster_desc
)
{
GridwiseGemm
{}.
Run
(
a_k_m_global_desc
,
GridwiseGemm
::
Run
(
p_a_global
,
b_k_n_global_desc
,
p_b_global
,
c_m0_m1_n0_n1_global_desc
,
p_c_global
,
a_k_m_global_desc
,
b_k_n_global_desc
,
c_m0_m1_n0_n1_global_desc
,
c_block_cluster_desc
,
integral_constant
<
bool
,
HasMainKBlockLoop
>
{},
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
{});
...
...
@@ -50,12 +51,12 @@ __global__ void
// __CONSTANT__ is needed to inform compiler void pointers in the kernel signature are pointing to
// non-modifiable parameter address space, so compiler can enable corresponding optimization
template
<
typename
GridwiseGemm
,
typename
AGlobalDesc
,
typename
FloatA
,
typename
BGlobalDesc
,
typename
FloatB
,
typename
CGlobalDesc
,
typename
FloatC
,
typename
AGlobalDesc
,
typename
BGlobalDesc
,
typename
CGlobalDesc
,
typename
CBlockClusterDesc
,
bool
HasMainKBlockLoop
,
bool
HasDoubleTailKBlockLoop
>
...
...
@@ -63,12 +64,12 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
kernel_dynamic_gemm_v1
(
const
void
__CONSTANT__
*
p_a_k_m_global_desc
,
const
FloatA
*
__restrict__
p_a_global
,
const
void
__CONSTANT__
*
p_b_k_n_global_desc
,
kernel_dynamic_gemm_v1
(
const
FloatA
*
__restrict__
p_a_global
,
const
FloatB
*
__restrict__
p_b_global
,
const
void
__CONSTANT__
*
p_c_m0_m1_n0_n1_global_desc
,
FloatC
*
__restrict__
p_c_global
,
const
void
__CONSTANT__
*
p_a_k_m_global_desc
,
const
void
__CONSTANT__
*
p_b_k_n_global_desc
,
const
void
__CONSTANT__
*
p_c_m0_m1_n0_n1_global_desc
,
const
void
__CONSTANT__
*
p_c_block_cluster_desc
)
{
// first cast void __CONSTANT__ void* to void*
...
...
@@ -84,12 +85,13 @@ __global__ void
const
auto
c_block_cluster_desc
=
*
reinterpret_cast
<
const
CBlockClusterDesc
*>
((
const
void
*
)
p_c_block_cluster_desc
);
GridwiseGemm
{}.
Run
(
a_k_m_global_desc
,
GridwiseGemm
::
Run
(
p_a_global
,
b_k_n_global_desc
,
p_b_global
,
c_m0_m1_n0_n1_global_desc
,
p_c_global
,
a_k_m_global_desc
,
b_k_n_global_desc
,
c_m0_m1_n0_n1_global_desc
,
c_block_cluster_desc
,
integral_constant
<
bool
,
HasMainKBlockLoop
>
{},
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
{});
...
...
@@ -169,16 +171,17 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
}
template
<
bool
HasMainKBlockLoop
,
bool
HasDoubleTailKBlockLoop
>
__device__
void
Run
(
const
AGlobalDesc
&
a_k_m_global_desc
,
__device__
static
void
Run
(
const
FloatAB
*
__restrict__
p_a_global
,
const
BGlobalDesc
&
b_k_n_global_desc
,
const
FloatAB
*
__restrict__
p_b_global
,
const
CGlobalDesc
&
c_m0_m1_n0_n1_global_desc
,
FloatC
*
__restrict__
p_c_global
,
const
AGlobalDesc
&
a_k_m_global_desc
,
const
BGlobalDesc
&
b_k_n_global_desc
,
const
CGlobalDesc
&
c_m0_m1_n0_n1_global_desc
,
const
CBlockClusterDesc
&
c_block_cluster_desc
,
FloatAB
*
__restrict__
p_shared_block
,
integral_constant
<
bool
,
HasMainKBlockLoop
>
,
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
)
const
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
)
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
...
...
@@ -514,26 +517,28 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
}
template
<
bool
HasMainKBlockLoop
,
bool
HasDoubleTailKBlockLoop
>
__device__
void
Run
(
const
AGlobalDesc
&
a_k_m_global_desc
,
__device__
static
void
Run
(
const
FloatAB
*
__restrict__
p_a_global
,
const
BGlobalDesc
&
b_k_n_global_desc
,
const
FloatAB
*
__restrict__
p_b_global
,
const
CGlobalDesc
&
c_m0_m1_n0_n1_global_desc
,
FloatC
*
__restrict__
p_c_global
,
const
AGlobalDesc
&
a_k_m_global_desc
,
const
BGlobalDesc
&
b_k_n_global_desc
,
const
CGlobalDesc
&
c_m0_m1_n0_n1_global_desc
,
const
CBlockClusterDesc
&
c_block_cluster_desc
,
integral_constant
<
bool
,
HasMainKBlockLoop
>
,
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
)
const
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
)
{
constexpr
index_t
shared_block_size
=
GetSharedMemoryNumberOfByte
()
/
sizeof
(
FloatAB
);
__shared__
FloatAB
p_shared_block
[
shared_block_size
];
Run
(
a_k_m_global_desc
,
Run
(
p_a_global
,
b_k_n_global_desc
,
p_b_global
,
c_m0_m1_n0_n1_global_desc
,
p_c_global
,
a_k_m_global_desc
,
b_k_n_global_desc
,
c_m0_m1_n0_n1_global_desc
,
c_block_cluster_desc
,
p_shared_block
,
integral_constant
<
bool
,
HasMainKBlockLoop
>
{},
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment