gaoqiong / composable_kernel / Commits

Commit be49a8c5, authored May 12, 2021 by Jing Zhang
merge master
Parents: bcdc330d, 71d6b19d
Changes: 22

Showing 20 changed files with 877 additions and 629 deletions (+877 -629)
composable_kernel/include/driver/driver_dynamic_gemm_v1.hpp                                                  +65  -73
composable_kernel/include/driver/driver_dynamic_gemm_xdlops_v1.hpp                                           +185 -64
composable_kernel/include/kernel_algorithm/transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp   +2   -3
composable_kernel/include/kernel_algorithm/transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp   +2   -3
composable_kernel/include/tensor_operation/blockwise_dynamic_tensor_slice_transfer.hpp                       +6   -9
composable_kernel/include/tensor_operation/blockwise_gemm_v2.hpp                                             +10  -12
composable_kernel/include/tensor_operation/blockwise_gemm_v3.hpp                                             +1   -3
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp                                         +103 -96
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v2.hpp                                      +20  -19
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops.hpp                                  +130 -76
composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp                      +121 -182
composable_kernel/include/utility/amd_buffer_addressing_v2.hpp                                               +6   -6
composable_kernel/include/utility/buffer.hpp                                                                 +0   -72
composable_kernel/include/utility/common_header.hpp                                                          +2   -1
composable_kernel/include/utility/config.amd.hpp.in                                                          +8   -2
composable_kernel/include/utility/dynamic_buffer.hpp                                                         +173 -0
composable_kernel/include/utility/static_buffer.hpp                                                          +33  -0
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp                      +2   -1
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp                      +2   -1
driver/include/device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp                      +6   -6
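The dominant theme across these files is the replacement of raw FloatAB*/FloatC* pointers by typed buffer objects: make_dynamic_buffer<AddressSpace::...>(...) views over global or LDS memory, and make_static_buffer<AddressSpace::Vgpr, T>(...) / StaticBuffer<AddressSpace::Vgpr, ...> for per-thread register storage (see the new dynamic_buffer.hpp and static_buffer.hpp). The following is a minimal standalone sketch of that idea, not the library's implementation; the names mirror the calls visible in this diff, but the members and behaviour are simplified stand-ins.

#include <array>
#include <cstddef>

// Simplified stand-ins for the address spaces used in the diff.
enum class AddressSpace { Generic, Global, Lds, Vgpr };

// A dynamic buffer is a typed view over memory whose size is known at run time.
template <AddressSpace AS, typename T>
struct DynamicBuffer
{
    T* p_data_;
    std::size_t element_space_size_;

    T& operator[](std::size_t i) { return p_data_[i]; }
    const T& operator[](std::size_t i) const { return p_data_[i]; }
};

template <AddressSpace AS, typename T>
DynamicBuffer<AS, T> make_dynamic_buffer(T* p, std::size_t element_space_size)
{
    return DynamicBuffer<AS, T>{p, element_space_size};
}

// A static buffer owns a compile-time-sized array, e.g. per-thread register storage.
template <AddressSpace AS, typename T, std::size_t N>
struct StaticBuffer
{
    std::array<T, N> data_{};

    T& operator[](std::size_t i) { return data_[i]; }
    const T& operator[](std::size_t i) const { return data_[i]; }
};

template <AddressSpace AS, typename T, std::size_t N>
StaticBuffer<AS, T, N> make_static_buffer()
{
    return StaticBuffer<AS, T, N>{};
}

int main()
{
    float storage[16] = {};
    // Views analogous to a_global_buf / c_thread_buf in the diff below.
    auto global_buf = make_dynamic_buffer<AddressSpace::Global>(storage, 16);
    auto thread_buf = make_static_buffer<AddressSpace::Vgpr, float, 8>();
    thread_buf[0]   = global_buf[0];
    return 0;
}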
composable_kernel/include/driver/driver_dynamic_gemm_v1.hpp

@@ -146,16 +146,16 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
     if(has_main_k_block_loop && has_double_tail_k_block_loop)
     {
-        const auto kernel = run_gridwise_operation<gridwise_gemm,
-                                                   remove_reference_t<AGlobalDesc>, const FloatAB*,
-                                                   remove_reference_t<BGlobalDesc>, const FloatAB*,
-                                                   remove_reference_t<CGlobalDesc>, FloatC*,
-                                                   remove_reference_t<CBlockClusterDesc>,
-                                                   integral_constant<bool, true>,
-                                                   integral_constant<bool, true>>;
+        const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
+                                                   FloatAB, FloatAB, FloatC,
+                                                   remove_reference_t<AGlobalDesc>,
+                                                   remove_reference_t<BGlobalDesc>,
+                                                   remove_reference_t<CGlobalDesc>,
+                                                   remove_reference_t<CBlockClusterDesc>,
+                                                   true, true>;

         ave_time = launch_and_time_kernel(kernel,
                                           nrepeat,

@@ -163,28 +163,26 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
                                           dim3(BlockSize),
                                           0,
                                           0,
-                                          a_k_m_global_desc, p_a_global,
-                                          b_k_n_global_desc, p_b_global,
-                                          c_m0_m1_n0_n1_global_desc, p_c_global,
-                                          c_block_cluster_desc,
-                                          integral_constant<bool, true>{},
-                                          integral_constant<bool, true>{});
+                                          p_a_global, p_b_global, p_c_global,
+                                          a_k_m_global_desc,
+                                          b_k_n_global_desc,
+                                          c_m0_m1_n0_n1_global_desc,
+                                          c_block_cluster_desc);
     }
     else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
     {

The remaining hunks of the CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE branch (@@ -192,28 +190,26 @@, @@ -221,28 +217,26 @@ and @@ -250,15 +244,13 @@) apply the same two edits to the other three loop cases: each run_gridwise_operation<gridwise_gemm, ..., integral_constant<bool, X>, integral_constant<bool, Y>> instantiation becomes kernel_dynamic_gemm_v1<gridwise_gemm, FloatAB, FloatAB, FloatC, remove_reference_t<AGlobalDesc>, remove_reference_t<BGlobalDesc>, remove_reference_t<CGlobalDesc>, remove_reference_t<CBlockClusterDesc>, X, Y> with (X, Y) = (true, false), (false, true), (false, false), and every launch_and_time_kernel call passes the data pointers before the descriptors and drops the two integral_constant<bool, ...>{} arguments; the trailing "return ave_time;" is unchanged.

In the CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER branch (@@ -277,13 +269,13 @@ through @@ -379,12 +371,12 @@) the kernel template arguments are reordered the same way (element types FloatAB, FloatAB, FloatC first, the descriptor types after; the first case also switches from run_gridwise_operation to kernel_dynamic_gemm_v1), and each launch now passes the raw pointers first and the descriptor device buffers last, for example:

@@ -295,23 +287,23 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
                                           dim3(BlockSize),
                                           0,
                                           0,
-                                          (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(), p_a_global,
-                                          (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(), p_b_global,
-                                          (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(), p_c_global,
-                                          (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
+                                          p_a_global, p_b_global, p_c_global,
+                                          (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
+                                          (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
+                                          (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
+                                          (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
     }
     else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
     {
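Both launch helpers in this driver map two runtime booleans (has_main_k_block_loop, has_double_tail_k_block_loop) onto compile-time template parameters by enumerating the four cases in an if/else chain. The standalone sketch below shows only that dispatch pattern; do_work and launch are hypothetical placeholders, and the has_main_loop predicate is a made-up example, while the has_double_tail formula is the one visible in the xdlops driver.

#include <cstdio>

// A "kernel" whose behaviour is fixed at compile time by two bool parameters,
// standing in for HasMainKBlockLoop / HasDoubleTailKBlockLoop.
template <bool HasMainLoop, bool HasDoubleTail>
void do_work()
{
    if constexpr(HasMainLoop)
        std::puts("main k-block loop compiled in");
    if constexpr(HasDoubleTail)
        std::puts("double-tail k-block loop compiled in");
}

// Runtime-to-compile-time dispatch, mirroring the if/else chain in the driver.
void launch(bool has_main_loop, bool has_double_tail)
{
    if(has_main_loop && has_double_tail)
        do_work<true, true>();
    else if(has_main_loop && !has_double_tail)
        do_work<true, false>();
    else if(!has_main_loop && has_double_tail)
        do_work<false, true>();
    else
        do_work<false, false>();
}

int main()
{
    const int K = 256, KPerBlock = 32;
    const bool has_main_loop   = (K / KPerBlock) > 2; // placeholder predicate, not the library's exact formula
    const bool has_double_tail = (K / KPerBlock) % 2 == 0;
    launch(has_main_loop, has_double_tail);
    return 0;
}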
composable_kernel/include/driver/driver_dynamic_gemm_xdlops_v1.hpp

@@ -141,20 +141,21 @@ __host__ float launch_kernel_dynamic_gemm_xdlops_v1(const FloatAB* p_a_global,
     const bool has_double_tail_k_block_loop = (K / KPerBlock) % 2 == 0;

+#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
     float ave_time = 0;

     if(has_main_k_block_loop && has_double_tail_k_block_loop)
     {
-        const auto kernel = run_gridwise_operation<gridwise_gemm,
-                                                   remove_reference_t<AGlobalDesc>, const FloatAB*,
-                                                   remove_reference_t<BGlobalDesc>, const FloatAB*,
-                                                   remove_reference_t<CGlobalDesc>, FloatC*,
-                                                   remove_reference_t<CBlockClusterDesc>,
-                                                   integral_constant<bool, true>,
-                                                   integral_constant<bool, true>>;
+        const auto kernel = kernel_dynamic_gemm_xdlops_v1<gridwise_gemm,
+                                                          FloatAB, FloatAB, FloatC,
+                                                          remove_reference_t<AGlobalDesc>,
+                                                          remove_reference_t<BGlobalDesc>,
+                                                          remove_reference_t<CGlobalDesc>,
+                                                          remove_reference_t<CBlockClusterDesc>,
+                                                          true, true>;

         ave_time = launch_and_time_kernel(kernel,
                                           nrepeat,

The hunks @@ -162,28 +163,26 @@, @@ -191,28 +190,26 @@ and @@ -220,28 +217,26 @@ mirror the changes in driver_dynamic_gemm_v1.hpp for the other three loop cases: run_gridwise_operation<..., integral_constant<bool, X>, integral_constant<bool, Y>> becomes kernel_dynamic_gemm_xdlops_v1<gridwise_gemm, FloatAB, FloatAB, FloatC, remove_reference_t<AGlobalDesc>, remove_reference_t<BGlobalDesc>, remove_reference_t<CGlobalDesc>, remove_reference_t<CBlockClusterDesc>, X, Y>, and each launch_and_time_kernel call passes (p_a_global, p_b_global, p_c_global, a_k_m_global_desc, b_k_n_global_desc, c_m0_m1_n0_n1_global_desc, c_block_cluster_desc) instead of the old descriptor/pointer/integral_constant mix.

@@ -249,18 +244,144 @@ __host__ float launch_kernel_dynamic_gemm_xdlops_v1(const FloatAB* p_a_global,
The last hunk rewrites the final launch of the BY_VALUE branch the same way and then adds a complete CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER branch:

+    return ave_time;
+#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
+    DeviceMem a_k_m_global_desc_device_buf(sizeof(AGlobalDesc));
+    DeviceMem b_k_n_global_desc_device_buf(sizeof(BGlobalDesc));
+    DeviceMem c_m0_m1_n0_n1_global_desc_device_buf(sizeof(CGlobalDesc));
+    DeviceMem c_block_cluster_desc_device_buf(sizeof(c_block_cluster_desc));
+
+    a_k_m_global_desc_device_buf.ToDevice(&a_k_m_global_desc);
+    b_k_n_global_desc_device_buf.ToDevice(&b_k_n_global_desc);
+    c_m0_m1_n0_n1_global_desc_device_buf.ToDevice(&c_m0_m1_n0_n1_global_desc);
+    c_block_cluster_desc_device_buf.ToDevice(&c_block_cluster_desc);
+
+    float ave_time = 0;
+
+    if(has_main_k_block_loop && has_double_tail_k_block_loop)
+    {
+        const auto kernel = kernel_dynamic_gemm_xdlops_v1<gridwise_gemm,
+                                                          FloatAB, FloatAB, FloatC,
+                                                          remove_reference_t<AGlobalDesc>,
+                                                          remove_reference_t<BGlobalDesc>,
+                                                          remove_reference_t<CGlobalDesc>,
+                                                          remove_reference_t<CBlockClusterDesc>,
+                                                          true, true>;
+
+        ave_time = launch_and_time_kernel(kernel,
+                                          nrepeat,
+                                          dim3(GridSize),
+                                          dim3(BlockSize),
+                                          0,
+                                          0,
+                                          p_a_global, p_b_global, p_c_global,
+                                          (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
+                                          (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
+                                          (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
+                                          (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
+    }

followed by analogous "else if" / "else" branches instantiating kernel_dynamic_gemm_xdlops_v1 with (true, false), (false, true) and (false, false) and launching with the same argument list, and ending with

+    return ave_time;
+#endif
 }
 } // namespace ck
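The new BY_VOID_POINTER branch copies each tensor descriptor into a device buffer with ToDevice() and hands the raw device address to the kernel as const void __CONSTANT__*, which the kernel side then reinterpret_casts back to the concrete descriptor type (see the kernel hunks in gridwise_dynamic_gemm.hpp below). The sketch here is a host-only analogue of that round trip under the assumption that the descriptor is trivially copyable; DeviceMem, ToDevice, GetDeviceBuffer and __CONSTANT__ are the repository's names, while Desc and the byte-array "device memory" are simplified stand-ins.

#include <cassert>
#include <cstring>
#include <vector>

// Hypothetical trivially copyable descriptor, standing in for the dynamic tensor descriptors.
struct Desc
{
    int lengths[2];
    int strides[2];
};

// Simplified stand-in for DeviceMem::ToDevice/GetDeviceBuffer: an opaque byte blob.
struct FakeDeviceMem
{
    std::vector<unsigned char> bytes;
    explicit FakeDeviceMem(std::size_t n) : bytes(n) {}
    void ToDevice(const void* p) { std::memcpy(bytes.data(), p, bytes.size()); }
    void* GetDeviceBuffer() { return bytes.data(); }
};

// Kernel-side view: the opaque pointer is cast back to the concrete descriptor type.
int read_length0(const void* p_desc_blob)
{
    const auto desc = *reinterpret_cast<const Desc*>(p_desc_blob);
    return desc.lengths[0];
}

int main()
{
    Desc a_desc{{128, 64}, {64, 1}};
    FakeDeviceMem a_desc_buf(sizeof(Desc));
    a_desc_buf.ToDevice(&a_desc); // host to "device" copy of the descriptor
    assert(read_length0(a_desc_buf.GetDeviceBuffer()) == 128);
    return 0;
}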
composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
  renamed to composable_kernel/include/kernel_algorithm/transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp

-#ifndef CK_DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_HPP
-#define CK_DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_HPP
+#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4_NCHW_KCYX_NKHW_HPP
+#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4_NCHW_KCYX_NKHW_HPP

 #include "common_header.hpp"
 #include "dynamic_tensor_descriptor.hpp"
 #include "dynamic_tensor_descriptor_helper.hpp"
-#include "driver_dynamic_gemm_v1.hpp"

 namespace ck {
composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp
  renamed to composable_kernel/include/kernel_algorithm/transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp

-#ifndef CK_DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R4_NHWC_KYXC_NHWK_HPP
-#define CK_DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R4_NHWC_KYXC_NHWK_HPP
+#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4_NHWC_KYXC_NHWK_HPP
+#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4_NHWC_KYXC_NHWK_HPP

 #include "common_header.hpp"
 #include "dynamic_tensor_descriptor.hpp"
 #include "dynamic_tensor_descriptor_helper.hpp"
-#include "driver_dynamic_gemm_v1.hpp"

 namespace ck {
composable_kernel/include/tensor_operation/blockwise_dynamic_tensor_slice_transfer.hpp

@@ -29,8 +29,6 @@ template <index_t BlockSize,
           index_t DstVectorDim,
           index_t SrcScalarPerVector,
           index_t DstScalarPerVector,
-          AddressSpace SrcAddressSpace,
-          AddressSpace DstAddressSpace,
           index_t SrcScalarStrideInVector,
           index_t DstScalarStrideInVector,
           index_t ThreadTransferSrcResetCoordinateAfterRun,

@@ -79,24 +77,25 @@ struct BlockwiseDynamicTensorSliceTransfer_v4
         }
     }

-    template <typename SrcIteratorHacks>
+    template <typename SrcBuffer, typename SrcIteratorHacks>
     __device__ void RunRead(const SrcDesc& src_desc,
-                            const SrcData* p_src,
+                            const SrcBuffer& src_buf,
                             const SrcIteratorHacks& src_iterator_hacks)
     {
         if(BlockSize == thread_cluster_desc_.GetElementSize() or
            get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
         {
-            threadwise_transfer_.RunRead(src_desc, p_src, src_iterator_hacks);
+            threadwise_transfer_.RunRead(src_desc, src_buf, src_iterator_hacks);
         }
     }

-    __device__ void RunWrite(const DstDesc& dst_desc, DstData* p_dst)
+    template <typename DstBuffer>
+    __device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf)
     {
         if(BlockSize == thread_cluster_desc_.GetElementSize() or
            get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
         {
-            threadwise_transfer_.RunWrite(dst_desc, p_dst);
+            threadwise_transfer_.RunWrite(dst_desc, dst_buf);
         }
     }

@@ -152,8 +151,6 @@ struct BlockwiseDynamicTensorSliceTransfer_v4
                                                    DstScalarPerVector,
                                                    SrcScalarStrideInVector,
                                                    DstScalarStrideInVector,
-                                                   SrcAddressSpace,
-                                                   DstAddressSpace,
                                                    ThreadTransferSrcResetCoordinateAfterRun,
                                                    ThreadTransferDstResetCoordinateAfterRun>;
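After this change the blockwise transfer no longer takes raw SrcData*/DstData* pointers; RunRead and RunWrite are templated on a buffer type, so the same copy code can be fed a global-memory buffer, an LDS buffer or a register StaticBuffer. Below is a minimal host-side sketch of that shape; Desc, the two buffer types and the copy body are simplified placeholders, not the repository's classes.

#include <cstddef>
#include <vector>

// Hypothetical descriptor: here just a flat element count.
struct Desc { std::size_t element_count; };

// Two toy buffer types with the same operator[] surface, standing in for
// pointer-backed (dynamic) and owning (static) buffers.
struct SpanBuffer  { float* p; float& operator[](std::size_t i) const { return p[i]; } };
struct OwnedBuffer { std::vector<float> v; float& operator[](std::size_t i) { return v[i]; } };

struct BlockwiseTransfer
{
    // The transfer accepts any buffer types, mirroring the new template parameters.
    template <typename SrcBuffer, typename DstBuffer>
    void Run(const Desc& desc, const SrcBuffer& src_buf, DstBuffer& dst_buf) const
    {
        for(std::size_t i = 0; i < desc.element_count; ++i)
            dst_buf[i] = src_buf[i];
    }
};

int main()
{
    float global_storage[4] = {1.f, 2.f, 3.f, 4.f};
    SpanBuffer  src{global_storage};
    OwnedBuffer dst{std::vector<float>(4)};
    BlockwiseTransfer{}.Run(Desc{4}, src, dst); // same transfer object, different buffer types
    return 0;
}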
composable_kernel/include/tensor_operation/blockwise_gemm_v2.hpp

@@ -115,8 +115,10 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1
                            const BBlockBuffer& b_block_buf,
                            CThreadBuffer& c_thread_buf) const
     {
-        auto a_thread_buf = make_static_buffer<FloatA>(a_thread_desc_.GetElementSpaceSize());
-        auto b_thread_buf = make_static_buffer<FloatB>(b_thread_desc_.GetElementSpaceSize());
+        auto a_thread_buf =
+            make_static_buffer<AddressSpace::Vgpr, FloatA>(a_thread_desc_.GetElementSpaceSize());
+        auto b_thread_buf =
+            make_static_buffer<AddressSpace::Vgpr, FloatB>(b_thread_desc_.GetElementSpaceSize());

         constexpr auto threadwise_gemm =
             ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1<FloatA,

@@ -176,8 +178,6 @@ and @@ -189,8 +189,6 @@ remove the explicit address spaces from the A and B thread copies:
                                            Sequence<0, 1, 2>,
                                            2,
                                            AThreadCopyScalarPerVector_M1,
-                                           AddressSpace::Generic,
-                                           AddressSpace::Vgpr,
                                            1>;

     using BThreadCopy =
(the same two lines are removed after BThreadCopyScalarPerVector_N1, before "CIndex c_thread_origin_data_idx_;").

@@ -211,6 +209,8 @@ adds two comment lines ahead of the pipelined variant:
 // 3. C:
 // 1. CThreadDesc is known at compile-time
 // 2. CThreadBuffer is StaticBuffer
+// Also assume:
+// M0 = N0 = 2. It will do 2x2 pipelined read and fma (ABBA optimization)
 template <index_t BlockSize,
           typename FloatA,
           typename FloatB,

@@ -312,8 +312,10 @@, @@ -481,8 +483,6 @@ and @@ -494,8 +494,6 @@ repeat the make_static_buffer<AddressSpace::Vgpr, ...> change and the two AddressSpace removals inside BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2_pipeline_2x2 itself.
composable_kernel/include/tensor_operation/blockwise_gemm_v3.hpp

@@ -49,8 +49,6 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
                                             Sequence<0, 1>,
                                             1,
                                             ThreadGemmADataPerRead_K,
-                                            AddressSpace::Generic,
-                                            AddressSpace::Vgpr,
                                             1>;

     __device__ BlockwiseGemm_km_kn_m0m1n0n1_v3()

@@ -140,7 +138,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
         static_assert(WPerThread % WoPerThreadSubC == 0, "");

         // thread A buffer for GEMM
-        StaticBuffer<FloatA, a_thread_mtx_.GetElementSpaceSize()> a_thread_buf;
+        StaticBuffer<AddressSpace::Vgpr, FloatA, a_thread_mtx_.GetElementSpaceSize()> a_thread_buf;

         constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v3<FloatA,
                                                                     FloatB,
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp

@@ -14,29 +14,33 @@ namespace ck {
 #if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
 template <typename GridwiseGemm,
-          typename AGlobalDesc,
           typename FloatA,
-          typename BGlobalDesc,
           typename FloatB,
-          typename CGlobalDesc,
           typename FloatC,
+          typename AGlobalDesc,
+          typename BGlobalDesc,
+          typename CGlobalDesc,
           typename CBlockClusterDesc,
           bool HasMainKBlockLoop,
           bool HasDoubleTailKBlockLoop>
-__global__ void kernel_dynamic_gemm_v1(const AGlobalDesc a_k_m_global_desc,
-                                       const FloatA* __restrict__ p_a_global,
-                                       const BGlobalDesc b_k_n_global_desc,
-                                       const FloatB* __restrict__ p_b_global,
-                                       const CGlobalDesc c_m0_m1_n0_n1_global_desc,
-                                       FloatC* __restrict__ p_c_global,
-                                       const CBlockClusterDesc c_block_cluster_desc)
+__global__ void
+#if CK_USE_LAUNCH_BOUNDS
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+#endif
+    kernel_dynamic_gemm_v1(const FloatA* __restrict__ p_a_global,
+                           const FloatB* __restrict__ p_b_global,
+                           FloatC* __restrict__ p_c_global,
+                           const AGlobalDesc a_k_m_global_desc,
+                           const BGlobalDesc b_k_n_global_desc,
+                           const CGlobalDesc c_m0_m1_n0_n1_global_desc,
+                           const CBlockClusterDesc c_block_cluster_desc)
 {
-    GridwiseGemm{}.Run(a_k_m_global_desc, p_a_global,
-                       b_k_n_global_desc, p_b_global,
-                       c_m0_m1_n0_n1_global_desc, p_c_global,
-                       c_block_cluster_desc,
-                       integral_constant<bool, HasMainKBlockLoop>{},
-                       integral_constant<bool, HasDoubleTailKBlockLoop>{});
+    GridwiseGemm::Run(p_a_global, p_b_global, p_c_global,
+                      a_k_m_global_desc, b_k_n_global_desc, c_m0_m1_n0_n1_global_desc,
+                      c_block_cluster_desc,
+                      integral_constant<bool, HasMainKBlockLoop>{},
+                      integral_constant<bool, HasDoubleTailKBlockLoop>{});

The CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER kernel gets the same treatment (@@ -46,21 +50,25 @@ and @@ -76,12 +84,12 @@): the kept comment "// __CONSTANT__ is needed to inform compiler void pointers in the kernel signature are pointing to non-modifiable parameter address space, so compiler can enable corresponding optimization" stays, the template parameters are reordered to put FloatA, FloatB, FloatC first, the optional __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) attribute is added, the kernel now takes (p_a_global, p_b_global, p_c_global, p_a_k_m_global_desc, p_b_k_n_global_desc, p_c_m0_m1_n0_n1_global_desc, p_c_block_cluster_desc), and after casting the __CONSTANT__ pointers back to descriptors the body calls GridwiseGemm::Run(p_a_global, p_b_global, p_c_global, a_k_m_global_desc, b_k_n_global_desc, c_m0_m1_n0_n1_global_desc, c_block_cluster_desc, ...) instead of GridwiseGemm{}.Run(...).

@@ -161,22 +169,29 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
     }

     template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
-    __device__ void Run(const AGlobalDesc& a_k_m_global_desc,
-                        const FloatAB* __restrict__ p_a_global,
-                        const BGlobalDesc& b_k_n_global_desc,
-                        const FloatAB* __restrict__ p_b_global,
-                        const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
-                        FloatC* __restrict__ p_c_global,
-                        const CBlockClusterDesc& c_block_cluster_desc,
-                        FloatAB* __restrict__ p_shared_block,
-                        integral_constant<bool, HasMainKBlockLoop>,
-                        integral_constant<bool, HasDoubleTailKBlockLoop>) const
+    __device__ static void Run(const FloatAB* __restrict__ p_a_global,
+                               const FloatAB* __restrict__ p_b_global,
+                               FloatC* __restrict__ p_c_global,
+                               const AGlobalDesc& a_k_m_global_desc,
+                               const BGlobalDesc& b_k_n_global_desc,
+                               const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
+                               const CBlockClusterDesc& c_block_cluster_desc,
+                               FloatAB* __restrict__ p_shared_block,
+                               integral_constant<bool, HasMainKBlockLoop>,
+                               integral_constant<bool, HasDoubleTailKBlockLoop>)
     {
         constexpr auto I0 = Number<0>{};
         constexpr auto I1 = Number<1>{};
         constexpr auto I2 = Number<2>{};
         constexpr auto I3 = Number<3>{};

+        const auto a_global_buf = make_dynamic_buffer<AddressSpace::Global>(
+            p_a_global, a_k_m_global_desc.GetElementSpaceSize());
+        const auto b_global_buf = make_dynamic_buffer<AddressSpace::Global>(
+            p_b_global, b_k_n_global_desc.GetElementSpaceSize());
+        auto c_global_buf = make_dynamic_buffer<AddressSpace::Global>(
+            p_c_global, c_m0_m1_n0_n1_global_desc.GetElementSpaceSize());
+
         const auto K = a_k_m_global_desc.GetLength(I0);
         const auto M = a_k_m_global_desc.GetLength(I1);
         const auto N = b_k_n_global_desc.GetLength(I1);

@@ -226,8 +241,6 @@ and @@ -255,8 +268,6 @@ drop the explicit address spaces from the A and B blockwise copies:
                                                 ABlockTransferSrcScalarPerVector,
                                                 ABlockTransferDstScalarPerVector_M,
-                                                AddressSpace::Global,
-                                                AddressSpace::Lds,
                                                 1,
                                                 1,
                                                 AThreadTransferSrcResetCoordinateAfterRun,
(and likewise after BBlockTransferSrcScalarPerVector / BBlockTransferDstScalarPerVector_N for the B copy).

@@ -331,8 +342,8 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
         FloatAB* p_b_block_double = p_shared_block + 2 * a_block_space_size;

         // register allocation for output
-        auto c_thread_buf = make_static_buffer<FloatAcc>(c_m0_m1_n0_n1_thread_desc.GetElementSpaceSize());
+        auto c_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatAcc>(
+            c_m0_m1_n0_n1_thread_desc.GetElementSpaceSize());

         ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
                                            decltype(c_m0_m1_n0_n1_thread_desc),

@@ -353,25 +364,23 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
         constexpr auto b_k_n_global_move_slice_window_iterator_hack =
             BGlobalMoveSliceWindowIteratorHacks{};

-        FloatAB* p_a_block_even = p_a_block_double;
-        FloatAB* p_b_block_even = p_b_block_double;
-
-        FloatAB* p_a_block_odd = p_a_block_double + a_block_space_size;
-        FloatAB* p_b_block_odd = p_b_block_double + b_block_space_size;
-
-        auto a_block_even_buf = make_dynamic_buffer(p_a_block_even);
-        auto b_block_even_buf = make_dynamic_buffer(p_b_block_even);
-        auto a_block_odd_buf  = make_dynamic_buffer(p_a_block_odd);
-        auto b_block_odd_buf  = make_dynamic_buffer(p_b_block_odd);
+        auto a_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(
+            p_a_block_double, a_k_m_block_desc.GetElementSpaceSize());
+        auto b_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(
+            p_b_block_double, b_k_n_block_desc.GetElementSpaceSize());
+        auto a_block_odd_buf = make_dynamic_buffer<AddressSpace::Lds>(
+            p_a_block_double + a_block_space_size, a_k_m_block_desc.GetElementSpaceSize());
+        auto b_block_odd_buf = make_dynamic_buffer<AddressSpace::Lds>(
+            p_b_block_double + b_block_space_size, b_k_n_block_desc.GetElementSpaceSize());

         // LDS double buffer: preload data into LDS
         {
-            a_blockwise_copy.RunRead(a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks);
-            b_blockwise_copy.RunRead(b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
+            a_blockwise_copy.RunRead(a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks);
+            b_blockwise_copy.RunRead(b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);

-            a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_double);
-            b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_double);
+            a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_even_buf);
+            b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_even_buf);
         }

         if constexpr(HasMainKBlockLoop)

@@ -394,16 +403,16 @@ and @@ -417,16 +426,16 @@ make the matching substitutions inside the unrolled main loop: RunRead now reads from a_global_buf / b_global_buf, the blockwise_gemm.Run calls on the even and odd buffers are unchanged, and RunWrite stores into a_block_odd_buf / b_block_odd_buf (even and odd swapped in the second half of the loop), with "k_block_data_begin += 2 * KPerBlock;" and the do/while condition unchanged.

@@ -445,15 +454,15 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
         __syncthreads();

         // LDS double buffer: load last data from device mem
-        a_blockwise_copy.RunRead(a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks);
-        b_blockwise_copy.RunRead(b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
+        a_blockwise_copy.RunRead(a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks);
+        b_blockwise_copy.RunRead(b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);

         // LDS double buffer: GEMM on 2nd-last data
         blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);

         // LDS double buffer: store last data to LDS
-        a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_double + a_block_space_size);
-        b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_double + b_block_space_size);
+        a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_odd_buf);
+        b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_odd_buf);

         __syncthreads();

@@ -488,8 +497,6 @@ drops the address-space arguments from the C threadwise transfer:
                 CThreadTransferSrcDstAccessOrder,
                 CThreadTransferSrcDstVectorDim,
                 CThreadTransferDstScalarPerVector,
-                AddressSpace::Vgpr,
-                AddressSpace::Global,
                 CGlobalMemoryDataOperation,
                 1,
                 true>{

@@ -502,32 +509,32 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
                 make_tuple(I0, I0, I0, I0),
                 c_thread_buf,
                 c_m0_m1_n0_n1_global_desc,
-                p_c_global,
+                c_global_buf,
                 c_m0_m1_n0_n1_global_tensor_iterator_hacks);
         }
     }

     template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
-    __device__ void Run(const AGlobalDesc& a_k_m_global_desc,
-                        const FloatAB* __restrict__ p_a_global,
-                        const BGlobalDesc& b_k_n_global_desc,
-                        const FloatAB* __restrict__ p_b_global,
-                        const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
-                        FloatC* __restrict__ p_c_global,
-                        const CBlockClusterDesc& c_block_cluster_desc,
-                        integral_constant<bool, HasMainKBlockLoop>,
-                        integral_constant<bool, HasDoubleTailKBlockLoop>) const
+    __device__ static void Run(const FloatAB* __restrict__ p_a_global,
+                               const FloatAB* __restrict__ p_b_global,
+                               FloatC* __restrict__ p_c_global,
+                               const AGlobalDesc& a_k_m_global_desc,
+                               const BGlobalDesc& b_k_n_global_desc,
+                               const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
+                               const CBlockClusterDesc& c_block_cluster_desc,
+                               integral_constant<bool, HasMainKBlockLoop>,
+                               integral_constant<bool, HasDoubleTailKBlockLoop>)
     {
         constexpr index_t shared_block_size =
             GetSharedMemoryNumberOfByte() / sizeof(FloatAB);

         __shared__ FloatAB p_shared_block[shared_block_size];

-        Run(a_k_m_global_desc, p_a_global,
-            b_k_n_global_desc, p_b_global,
-            c_m0_m1_n0_n1_global_desc, p_c_global,
-            c_block_cluster_desc,
+        Run(p_a_global, p_b_global, p_c_global,
+            a_k_m_global_desc, b_k_n_global_desc, c_m0_m1_n0_n1_global_desc,
+            c_block_cluster_desc,
             p_shared_block,
             integral_constant<bool, HasMainKBlockLoop>{},
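The Run body above keeps the LDS double-buffering scheme intact (preload, then alternate even and odd buffers so the next tile is fetched while the current one is multiplied); only the raw pointer arguments were replaced by buffer objects. The sketch below shows that ping-pong structure on its own, with plain arrays instead of LDS and a trivial reduction in place of the blockwise GEMM; all names are illustrative and none come from the repository.

#include <cstdio>
#include <vector>

// Toy stand-ins: "global memory" is a vector, an LDS tile is a small array.
constexpr int KPerBlock = 4;

void load_tile(const std::vector<float>& global, int k_begin, float (&tile)[KPerBlock])
{
    for(int i = 0; i < KPerBlock; ++i)
        tile[i] = global[k_begin + i];
}

float gemm_on_tile(const float (&tile)[KPerBlock])
{
    float acc = 0.f;
    for(float v : tile)
        acc += v; // placeholder for the blockwise GEMM on the current tile
    return acc;
}

int main()
{
    const int K = 16;
    std::vector<float> a_global(K, 1.f);

    float tile_even[KPerBlock], tile_odd[KPerBlock];
    float c = 0.f;

    // Preload the first tile (the "LDS double buffer: preload data into LDS" step).
    load_tile(a_global, 0, tile_even);

    // Main loop: fetch the next tile into one buffer while computing on the other.
    for(int k = 0; k + 2 * KPerBlock <= K; k += 2 * KPerBlock)
    {
        load_tile(a_global, k + KPerBlock, tile_odd); // prefetch next
        c += gemm_on_tile(tile_even);                 // compute on current
        if(k + 2 * KPerBlock < K)
            load_tile(a_global, k + 2 * KPerBlock, tile_even);
        c += gemm_on_tile(tile_odd);
    }

    std::printf("c = %f\n", c);
    return 0;
}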
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v2.hpp
View file @
be49a8c5
...
@@ -84,6 +84,13 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
         constexpr auto I2 = Number<2>{};
         constexpr auto I3 = Number<3>{};

+        const auto a_global_buf = make_dynamic_buffer<AddressSpace::Global>(
+            p_a_global, a_e_k_global_desc.GetElementSpaceSize());
+        const auto b_global_buf = make_dynamic_buffer<AddressSpace::Global>(
+            p_b_global, b_e_n_ho_wo_global_desc.GetElementSpaceSize());
+        auto c_global_buf = make_dynamic_buffer<AddressSpace::Global>(
+            p_c_global, c_k_n_ho_wo_global_desc.GetElementSpaceSize());
+
         constexpr auto E = EPerBlock * 3 * 3;
         // const auto E = a_e_k_global_desc.GetLength(I0);
...
@@ -192,8 +199,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
             1,
             ABlockTransferSrcScalarPerVector,
             ABlockTransferDstScalarPerVector_K,
-            AddressSpace::Global,
-            AddressSpace::Lds,
             1,
             1,
             AThreadTransferSrcResetCoordinateAfterRun,
...
@@ -216,19 +221,17 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
             BBlockTransferSrcAccessOrder,
             BBlockTransferSrcVectorDim,
             BBlockTransferSrcScalarPerVector,
-            AddressSpace::Global,
-            AddressSpace::Vgpr,
             InMemoryDataOperation::Set,
             1,
             true>(b_e_n_ho_wo_global_desc,
                   make_multi_index(0, 0, ho_thread_data_on_global, wo_thread_data_on_global));

-        FloatAB* p_a_block = p_shared_block;
-        auto a_block_buf = make_dynamic_buffer(p_a_block);
+        auto a_block_buf = make_dynamic_buffer<AddressSpace::Lds>(
+            p_shared_block, a_e_k_desc.GetElementSpaceSize());

         // register allocation for output
-        StaticBuffer<FloatAcc, c_k_n_ho_wo_thread_desc.GetElementSpaceSize()> c_thread_buf;
+        StaticBuffer<AddressSpace::Vgpr, FloatAcc, c_k_n_ho_wo_thread_desc.GetElementSpaceSize()>
+            c_thread_buf;

         // initialize output thread tensor
         ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
...
@@ -250,21 +253,21 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
             BGlobalMoveSliceWindowIteratorHacks{};

         // double regsiter buffer for b
-        StaticBuffer<FloatAB, b_e_n_ho_wo_thread_desc.GetElementSpaceSize()> b_thread_even_buf,
-            b_thread_odd_buf;
+        StaticBuffer<AddressSpace::Vgpr, FloatAB, b_e_n_ho_wo_thread_desc.GetElementSpaceSize()>
+            b_thread_even_buf, b_thread_odd_buf;

         // LDS double buffer: preload data
         {
-            a_blockwise_copy.RunRead(a_e_k_global_desc, p_a_global, a_e_k_global_iterator_hacks);
+            a_blockwise_copy.RunRead(a_e_k_global_desc, a_global_buf, a_e_k_global_iterator_hacks);

             b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
-                                      p_b_global,
+                                      b_global_buf,
                                       b_e_n_ho_wo_thread_desc,
                                       make_tuple(I0, I0, I0, I0),
                                       b_thread_even_buf,
                                       b_e_n_ho_wo_global_iterator_hacks);

-            a_blockwise_copy.RunWrite(a_e_k_desc, p_a_block);
+            a_blockwise_copy.RunWrite(a_e_k_desc, a_block_buf);
         }

         __syncthreads();
...
@@ -282,7 +285,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
                                                           b_thread_slice_copy_step);

                 b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
-                                          p_b_global,
+                                          b_global_buf,
                                           b_e_n_ho_wo_thread_desc,
                                           make_tuple(I0, I0, I0, I0),
                                           b_thread_odd_buf,
...
@@ -298,7 +301,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
                                                           b_thread_slice_copy_step);

                 b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
-                                          p_b_global,
+                                          b_global_buf,
                                           b_e_n_ho_wo_thread_desc,
                                           make_tuple(I0, I0, I0, I0),
                                           b_thread_even_buf,
...
@@ -321,7 +324,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
                                                       b_thread_slice_copy_step);

             b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
-                                      p_b_global,
+                                      b_global_buf,
                                       b_e_n_ho_wo_thread_desc,
                                       make_tuple(I0, I0, I0, I0),
                                       b_thread_odd_buf,
...
@@ -358,8 +361,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
             CThreadTransferSrcDstAccessOrder,
             CThreadTransferSrcDstVectorDim,
             CThreadTransferDstScalarPerVector,
-            AddressSpace::Vgpr,
-            AddressSpace::Global,
             CGlobalMemoryDataOperation,
             1,
             true>(
...
@@ -370,7 +371,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
                 make_tuple(I0, I0, I0, I0),
                 c_thread_buf,
                 c_k_n_ho_wo_global_desc,
-                p_c_global,
+                c_global_buf,
                 c_k_n_ho_wo_global_tensor_iterator_hacks);
         }
     }
...
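To make these hunks easier to follow: the blockwise and threadwise copies now receive buffer handles (a_global_buf, b_global_buf, a_block_buf, c_global_buf) instead of raw pointers, and the accumulator becomes a StaticBuffer tagged AddressSpace::Vgpr, so the address space travels with the handle's type rather than with the AddressSpace::... template parameters that the hunks delete. Below is a minimal standalone sketch of that idea; the names are illustrative stand-ins, not the CK types themselves.

// Standalone sketch: the address space becomes part of the buffer's type, so
// transfer code can query it instead of taking it as an extra template parameter.
#include <cstddef>

enum class AddrSpace { Generic, Global, Lds, Vgpr };

template <AddrSpace AS, typename T>
struct TaggedBuffer
{
    T* p_data_;
    std::size_t element_space_size_;

    static constexpr AddrSpace GetAddressSpace() { return AS; }

    const T& operator[](std::size_t i) const { return p_data_[i]; }
    T& operator()(std::size_t i) { return p_data_[i]; }
};

template <AddrSpace AS, typename T>
constexpr TaggedBuffer<AS, T> make_tagged_buffer(T* p, std::size_t n)
{
    return {p, n};
}

// Usage mirroring the hunks above: wrap a global pointer once, then hand the
// wrapper to RunRead/RunWrite-style helpers that query GetAddressSpace().
float g_a[256];
auto a_global_buf_example = make_tagged_buffer<AddrSpace::Global>(g_a, 256);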
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops.hpp
View file @
be49a8c5
...
@@ -12,34 +12,89 @@
 namespace ck {

+#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
 template <typename GridwiseGemm,
-          typename AGlobalDesc,
           typename FloatA,
-          typename BGlobalDesc,
           typename FloatB,
-          typename CGlobalDesc,
           typename FloatC,
+          typename AGlobalDesc,
+          typename BGlobalDesc,
+          typename CGlobalDesc,
           typename CBlockClusterDesc,
           bool HasMainKBlockLoop,
           bool HasDoubleTailKBlockLoop>
-__global__ void
-kernel_dynamic_gemm_xdlops_v1(const AGlobalDesc a_k_m_global_desc,
-                              const FloatA* __restrict__ p_a_global,
-                              const BGlobalDesc b_k_n_global_desc,
-                              const FloatB* __restrict__ p_b_global,
-                              const CGlobalDesc c_m0_m1_n0_n1_global_desc,
-                              FloatC* __restrict__ p_c_global,
-                              const CBlockClusterDesc c_block_cluster_desc)
+__global__ void
+#if CK_USE_LAUNCH_BOUNDS
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+#endif
+    kernel_dynamic_gemm_xdlops_v1(const FloatA* __restrict__ p_a_global,
+                                  const FloatB* __restrict__ p_b_global,
+                                  FloatC* __restrict__ p_c_global,
+                                  const AGlobalDesc a_k_m_global_desc,
+                                  const BGlobalDesc b_k_n_global_desc,
+                                  const CGlobalDesc c_m0_m1_n0_n1_global_desc,
+                                  const CBlockClusterDesc c_block_cluster_desc)
 {
-    GridwiseGemm{}.Run(a_k_m_global_desc,
-                       p_a_global,
-                       b_k_n_global_desc,
-                       p_b_global,
-                       c_m0_m1_n0_n1_global_desc,
-                       p_c_global,
-                       c_block_cluster_desc,
-                       integral_constant<bool, HasMainKBlockLoop>{},
-                       integral_constant<bool, HasDoubleTailKBlockLoop>{});
+    GridwiseGemm::Run(p_a_global,
+                      p_b_global,
+                      p_c_global,
+                      a_k_m_global_desc,
+                      b_k_n_global_desc,
+                      c_m0_m1_n0_n1_global_desc,
+                      c_block_cluster_desc,
+                      integral_constant<bool, HasMainKBlockLoop>{},
+                      integral_constant<bool, HasDoubleTailKBlockLoop>{});
 }
+#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
+// pass tensor descriptor by __CONSTANT__ void pointer
+// __CONSTANT__ is needed to inform compiler void pointers in the kernel signature are pointing to
+// non-modifiable parameter address space, so compiler can enable corresponding optimization
+template <typename GridwiseGemm,
+          typename FloatA,
+          typename FloatB,
+          typename FloatC,
+          typename AGlobalDesc,
+          typename BGlobalDesc,
+          typename CGlobalDesc,
+          typename CBlockClusterDesc,
+          bool HasMainKBlockLoop,
+          bool HasDoubleTailKBlockLoop>
+__global__ void
+#if CK_USE_LAUNCH_BOUNDS
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+#endif
+    kernel_dynamic_gemm_xdlops_v1(const FloatA* __restrict__ p_a_global,
+                                  const FloatB* __restrict__ p_b_global,
+                                  FloatC* __restrict__ p_c_global,
+                                  const void __CONSTANT__* p_a_k_m_global_desc,
+                                  const void __CONSTANT__* p_b_k_n_global_desc,
+                                  const void __CONSTANT__* p_c_m0_m1_n0_n1_global_desc,
+                                  const void __CONSTANT__* p_c_block_cluster_desc)
+{
+    // first cast void __CONSTANT__ void* to void*
+    // second cast void* to Desc*
+    // the copy constructor of tensor descriptor doesn't take address_space(4)
+    const auto a_k_m_global_desc =
+        *reinterpret_cast<const AGlobalDesc*>((const void*)p_a_k_m_global_desc);
+    const auto b_k_n_global_desc =
+        *reinterpret_cast<const BGlobalDesc*>((const void*)p_b_k_n_global_desc);
+    const auto c_m0_m1_n0_n1_global_desc =
+        *reinterpret_cast<const CGlobalDesc*>((const void*)p_c_m0_m1_n0_n1_global_desc);
+    const auto c_block_cluster_desc =
+        *reinterpret_cast<const CBlockClusterDesc*>((const void*)p_c_block_cluster_desc);
+
+    GridwiseGemm::Run(p_a_global,
+                      p_b_global,
+                      p_c_global,
+                      a_k_m_global_desc,
+                      b_k_n_global_desc,
+                      c_m0_m1_n0_n1_global_desc,
+                      c_block_cluster_desc,
+                      integral_constant<bool, HasMainKBlockLoop>{},
+                      integral_constant<bool, HasDoubleTailKBlockLoop>{});
+}
+#endif

 template <index_t BlockSize,
           typename FloatAB,
...
@@ -114,22 +169,29 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
     }

     template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
-    __device__ void Run(const AGlobalDesc& a_k_m_global_desc,
-                        const FloatAB* __restrict__ p_a_global,
-                        const BGlobalDesc& b_k_n_global_desc,
-                        const FloatAB* __restrict__ p_b_global,
-                        const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
-                        FloatC* __restrict__ p_c_global,
-                        const CBlockClusterDesc& c_block_cluster_desc,
-                        FloatAB* __restrict__ p_shared_block,
-                        integral_constant<bool, HasMainKBlockLoop>,
-                        integral_constant<bool, HasDoubleTailKBlockLoop>) const
+    __device__ static void Run(const FloatAB* __restrict__ p_a_global,
+                               const FloatAB* __restrict__ p_b_global,
+                               FloatC* __restrict__ p_c_global,
+                               const AGlobalDesc& a_k_m_global_desc,
+                               const BGlobalDesc& b_k_n_global_desc,
+                               const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
+                               const CBlockClusterDesc& c_block_cluster_desc,
+                               FloatAB* __restrict__ p_shared_block,
+                               integral_constant<bool, HasMainKBlockLoop>,
+                               integral_constant<bool, HasDoubleTailKBlockLoop>)
     {
         constexpr auto I0 = Number<0>{};
         constexpr auto I1 = Number<1>{};
         constexpr auto I2 = Number<2>{};
         constexpr auto I3 = Number<3>{};

+        const auto a_global_buf = make_dynamic_buffer<AddressSpace::Global>(
+            p_a_global, a_k_m_global_desc.GetElementSpaceSize());
+        const auto b_global_buf = make_dynamic_buffer<AddressSpace::Global>(
+            p_b_global, b_k_n_global_desc.GetElementSpaceSize());
+        auto c_global_buf = make_dynamic_buffer<AddressSpace::Global>(
+            p_c_global, c_m0_m1_n0_n1_global_desc.GetElementSpaceSize());
+
         const auto K = a_k_m_global_desc.GetLength(I0);
         const auto M = a_k_m_global_desc.GetLength(I1);
         const auto N = b_k_n_global_desc.GetLength(I1);
...
@@ -179,8 +241,6 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
             1,
             ABlockTransferSrcScalarPerVector,
             ABlockTransferDstScalarPerVector_M,
-            AddressSpace::Global,
-            AddressSpace::Lds,
             1,
             1,
             AThreadTransferSrcResetCoordinateAfterRun,
...
@@ -208,8 +268,6 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
             1,
             BBlockTransferSrcScalarPerVector,
             BBlockTransferDstScalarPerVector_N,
-            AddressSpace::Global,
-            AddressSpace::Lds,
             1,
             1,
             BThreadTransferSrcResetCoordinateAfterRun,
...
@@ -284,8 +342,8 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
         FloatAB* p_b_block_double = p_shared_block + 2 * a_block_space_size;

         // register allocation for output
-        auto c_thread_buf =
-            make_static_buffer<FloatAcc>(c_m0_m1_n0_n1_thread_desc.GetElementSpaceSize());
+        auto c_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatAcc>(
+            c_m0_m1_n0_n1_thread_desc.GetElementSpaceSize());

         ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
                                            decltype(c_m0_m1_n0_n1_thread_desc),
...
@@ -306,25 +364,23 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
         constexpr auto b_k_n_global_move_slice_window_iterator_hack =
             BGlobalMoveSliceWindowIteratorHacks{};

-        FloatAB* p_a_block_even = p_a_block_double;
-        FloatAB* p_b_block_even = p_b_block_double;
-        FloatAB* p_a_block_odd = p_a_block_double + a_block_space_size;
-        FloatAB* p_b_block_odd = p_b_block_double + b_block_space_size;
-
-        auto a_block_even_buf = make_dynamic_buffer(p_a_block_even);
-        auto b_block_even_buf = make_dynamic_buffer(p_b_block_even);
-        auto a_block_odd_buf = make_dynamic_buffer(p_a_block_odd);
-        auto b_block_odd_buf = make_dynamic_buffer(p_b_block_odd);
+        auto a_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(
+            p_a_block_double, a_k_m_block_desc.GetElementSpaceSize());
+        auto b_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(
+            p_b_block_double, b_k_n_block_desc.GetElementSpaceSize());
+        auto a_block_odd_buf = make_dynamic_buffer<AddressSpace::Lds>(
+            p_a_block_double + a_block_space_size, a_k_m_block_desc.GetElementSpaceSize());
+        auto b_block_odd_buf = make_dynamic_buffer<AddressSpace::Lds>(
+            p_b_block_double + b_block_space_size, b_k_n_block_desc.GetElementSpaceSize());

         // LDS double buffer: preload data into LDS
         {
-            a_blockwise_copy.RunRead(a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks);
-            b_blockwise_copy.RunRead(b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
+            a_blockwise_copy.RunRead(a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks);
+            b_blockwise_copy.RunRead(b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);

-            a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_double);
-            b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_double);
+            a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_even_buf);
+            b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_even_buf);
         }

         if constexpr(HasMainKBlockLoop)
...
@@ -347,16 +403,16 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
                 // LDS doubel buffer: load next data from device mem
-                a_blockwise_copy.RunRead(
-                    a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks);
-                b_blockwise_copy.RunRead(
-                    b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
+                a_blockwise_copy.RunRead(
+                    a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks);
+                b_blockwise_copy.RunRead(
+                    b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);

                 // LDS double buffer: GEMM on current data
                 blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);

                 // LDS double buffer: store next data to LDS
-                a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_odd);
-                b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_odd);
+                a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_odd_buf);
+                b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_odd_buf);

                 // odd iteration
                 a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc,
...
@@ -370,16 +426,16 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
                 // LDS doubel buffer: load next data from device mem
-                a_blockwise_copy.RunRead(
-                    a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks);
-                b_blockwise_copy.RunRead(
-                    b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
+                a_blockwise_copy.RunRead(
+                    a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks);
+                b_blockwise_copy.RunRead(
+                    b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);

                 // LDS double buffer: GEMM on current data
                 blockwise_gemm.Run(a_block_odd_buf, b_block_odd_buf, c_thread_buf);

                 // LDS double buffer: store next data to LDS
-                a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_even);
-                b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_even);
+                a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_even_buf);
+                b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_even_buf);

                 k_block_data_begin += 2 * KPerBlock;
             } while(k_block_data_begin < K - 2 * KPerBlock);
...
@@ -398,15 +454,15 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
             __syncthreads();

             // LDS double buffer: load last data from device mem
-            a_blockwise_copy.RunRead(a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks);
-            b_blockwise_copy.RunRead(b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
+            a_blockwise_copy.RunRead(a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks);
+            b_blockwise_copy.RunRead(b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);

             // LDS double buffer: GEMM on 2nd-last data
             blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);

             // LDS double buffer: store last data to LDS
-            a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_double + a_block_space_size);
-            b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_double + b_block_space_size);
+            a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_odd_buf);
+            b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_odd_buf);

             __syncthreads();
...
@@ -441,8 +497,6 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
             CThreadTransferSrcDstAccessOrder,
             CThreadTransferSrcDstVectorDim,
             CThreadTransferDstScalarPerVector,
-            AddressSpace::Vgpr,
-            AddressSpace::Global,
             CGlobalMemoryDataOperation,
             1,
             true>{
...
@@ -455,32 +509,32 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
                 make_tuple(I0, I0, I0, I0),
                 c_thread_buf,
                 c_m0_m1_n0_n1_global_desc,
-                p_c_global,
+                c_global_buf,
                 c_m0_m1_n0_n1_global_tensor_iterator_hacks);
         }
     }

     template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
-    __device__ void Run(const AGlobalDesc& a_k_m_global_desc,
-                        const FloatAB* __restrict__ p_a_global,
-                        const BGlobalDesc& b_k_n_global_desc,
-                        const FloatAB* __restrict__ p_b_global,
-                        const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
-                        FloatC* __restrict__ p_c_global,
-                        const CBlockClusterDesc& c_block_cluster_desc,
-                        integral_constant<bool, HasMainKBlockLoop>,
-                        integral_constant<bool, HasDoubleTailKBlockLoop>) const
+    __device__ static void Run(const FloatAB* __restrict__ p_a_global,
+                               const FloatAB* __restrict__ p_b_global,
+                               FloatC* __restrict__ p_c_global,
+                               const AGlobalDesc& a_k_m_global_desc,
+                               const BGlobalDesc& b_k_n_global_desc,
+                               const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
+                               const CBlockClusterDesc& c_block_cluster_desc,
+                               integral_constant<bool, HasMainKBlockLoop>,
+                               integral_constant<bool, HasDoubleTailKBlockLoop>)
     {
         constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(FloatAB);

         __shared__ FloatAB p_shared_block[shared_block_size];

-        Run(a_k_m_global_desc,
-            p_a_global,
-            b_k_n_global_desc,
-            p_b_global,
-            c_m0_m1_n0_n1_global_desc,
-            p_c_global,
+        Run(p_a_global,
+            p_b_global,
+            p_c_global,
+            a_k_m_global_desc,
+            b_k_n_global_desc,
+            c_m0_m1_n0_n1_global_desc,
             c_block_cluster_desc,
             p_shared_block,
             integral_constant<bool, HasMainKBlockLoop>{},
...
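The Run() body above keeps two A tiles and two B tiles in LDS and alternates between them: while the GEMM consumes one pair, the next K-slice is fetched from global memory into the other pair, and the tail iterations are selected at compile time by HasMainKBlockLoop / HasDoubleTailKBlockLoop. A host-side sketch of just that schedule is given below; it is illustrative pseudostructure under the assumption of two tiles and a K loop, not the LDS/MFMA kernel itself.

// Sketch of the even/odd (ping-pong) schedule in Run() above: preload one tile,
// then alternate "load next slice into the idle tile / GEMM on the ready tile".
void double_buffer_schedule(int K, int KPerBlock)
{
    auto load_tile = [](int /*which*/, int /*k_begin*/) { /* global -> tile copy */ };
    auto gemm_tile = [](int /*which*/) { /* accumulate tile into C registers */ };

    int k = 0;
    load_tile(0, k); // preload the even tile
    while(k < K - 2 * KPerBlock)
    {
        load_tile(1, k + KPerBlock);     // fetch next slice into the odd tile
        gemm_tile(0);                    // compute on the even tile
        load_tile(0, k + 2 * KPerBlock); // fetch the following slice into the even tile
        gemm_tile(1);                    // compute on the odd tile
        k += 2 * KPerBlock;
    }
    // remaining one or two slices are handled by the tail code,
    // which is what HasMainKBlockLoop / HasDoubleTailKBlockLoop select at compile time
}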
composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp
View file @
be49a8c5
This diff is collapsed.
composable_kernel/include/utility/amd_buffer_addressing_v2.hpp
View file @
be49a8c5
...
@@ -323,7 +323,7 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
     }
     else if constexpr(N == 2)
     {
-#if !CK_WORKAROUND_SWDEV_XXXXXX
+#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
         return __llvm_amdgcn_raw_buffer_load_i8x2(
             src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
 #else
...
@@ -335,7 +335,7 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
     }
     else if constexpr(N == 4)
     {
-#if !CK_WORKAROUND_SWDEV_XXXXXX
+#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
         return __llvm_amdgcn_raw_buffer_load_i8x4(
             src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
 #else
...
@@ -347,7 +347,7 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
     }
     else if constexpr(N == 8)
     {
-#if !CK_WORKAROUND_SWDEV_XXXXXX
+#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
         vector_type<int8_t, 8> tmp;

         tmp.AsType<int8x4_t>()(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_i8x4(
...
@@ -369,7 +369,7 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
     }
     else if constexpr(N == 16)
     {
-#if !CK_WORKAROUND_SWDEV_XXXXXX
+#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
         vector_type<int8_t, 16> tmp;

         tmp.AsType<int8x4_t>()(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_i8x4(
...
@@ -483,7 +483,7 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type
     }
     else if constexpr(N == 2)
     {
-#if !CK_WORKAROUND_SWDEV_XXXXXX
+#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
         __llvm_amdgcn_raw_buffer_store_i8x2(src_thread_data,
                                             dst_wave_buffer_resource,
                                             dst_thread_addr_offset,
...
@@ -499,7 +499,7 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type
     }
     else if constexpr(N == 4)
     {
-#if !CK_WORKAROUND_SWDEV_XXXXXX
+#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
         __llvm_amdgcn_raw_buffer_store_i8x4(src_thread_data,
                                             dst_wave_buffer_resource,
                                             dst_thread_addr_offset,
...
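In the hunks above, when the workaround macro is active the wider int8 loads (N == 8 and N == 16) are assembled from 4-byte __llvm_amdgcn_raw_buffer_load_i8x4 pieces written into a vector_type temporary. The snippet below is a plain standalone C++ stand-in for that composition pattern, using memcpy instead of the buffer intrinsics; the struct and function names are illustrative only.

// Build a wider int8 vector out of 4-byte chunks so each hardware load stays at
// a size the compiler handles; mirrors the tmp.AsType<int8x4_t>() pattern above.
#include <cstdint>
#include <cstring>

struct int8x8_pieces
{
    int32_t lo; // bytes 0..3
    int32_t hi; // bytes 4..7
};

inline void load_int8x8_as_two_dwords(const int8_t* src, int8_t dst[8])
{
    int8x8_pieces tmp{};
    std::memcpy(&tmp.lo, src + 0, 4); // stands in for the first i8x4 buffer load
    std::memcpy(&tmp.hi, src + 4, 4); // stands in for the second i8x4 buffer load
    std::memcpy(dst, &tmp, 8);
}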
composable_kernel/include/utility/buffer.hpp
deleted
100644 → 0
View file @
bcdc330d
#ifndef CK_BUFFER_HPP
#define CK_BUFFER_HPP

#include "statically_indexed_array.hpp"

namespace ck {

template <typename T, index_t N>
struct StaticBuffer : public StaticallyIndexedArray<T, N>
{
    using type = T;
    using base = StaticallyIndexedArray<T, N>;

    __host__ __device__ constexpr StaticBuffer() : base{} {}

    __host__ __device__ static constexpr bool IsStaticBuffer() { return true; }

    __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
};

template <typename T, index_t N>
__host__ __device__ constexpr auto make_static_buffer(Number<N>)
{
    return StaticBuffer<T, N>{};
}

template <typename T>
struct DynamicBuffer
{
    using type = T;

    T* p_data_;

    __host__ __device__ constexpr DynamicBuffer(T* p_data) : p_data_{p_data} {}

    __host__ __device__ constexpr const T& operator[](index_t i) const { return p_data_[i]; }

    __host__ __device__ constexpr T& operator()(index_t i) { return p_data_[i]; }

    template <typename X,
              typename std::enable_if<
                  is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
                          typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
                  bool>::type = false>
    __host__ __device__ constexpr const auto Get(index_t i) const
    {
        return *reinterpret_cast<const X*>(&p_data_[i]);
    }

    template <typename X,
              typename std::enable_if<
                  is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
                          typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
                  bool>::type = false>
    __host__ __device__ void Set(index_t i, const X& x)
    {
        *reinterpret_cast<X*>(&p_data_[i]) = x;
    }

    __host__ __device__ static constexpr bool IsStaticBuffer() { return false; }

    __host__ __device__ static constexpr bool IsDynamicBuffer() { return true; }
};

template <typename T>
__host__ __device__ constexpr auto make_dynamic_buffer(T* p)
{
    return DynamicBuffer<T>{p};
}

} // namespace ck
#endif
composable_kernel/include/utility/common_header.hpp
View file @
be49a8c5
...
@@ -8,7 +8,6 @@
 #include "container_element_picker.hpp"
 #include "data_type.hpp"
 #include "float_type.hpp"
-#include "buffer.hpp"
 #include "functional.hpp"
 #include "functional2.hpp"
 #include "functional3.hpp"
...
@@ -25,6 +24,8 @@
 #include "type.hpp"
 #include "utility.hpp"
 #include "magic_division.hpp"
+#include "static_buffer.hpp"
+#include "dynamic_buffer.hpp"

 #if CK_USE_AMD_INLINE_ASM
 #include "amd_inline_asm.hpp"
...
composable_kernel/include/utility/config.amd.hpp.in
View file @
be49a8c5
...
@@ -143,8 +143,13 @@
 #endif

 // workaround for compiler crash when using buffer load/store for i8
-#ifndef CK_WORKAROUND_SWDEV_XXXXXX
-#define CK_WORKAROUND_SWDEV_XXXXXX 1
+#ifndef CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
+#define CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE 1
+#endif
+
+// workaround for compiler crash when using buffer load/store for i8
+#ifndef CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
+#define CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE 1
 #endif

 namespace ck {
...
@@ -154,6 +159,7 @@ enum AddressSpace
     Generic,
     Global,
     Lds,
+    Sgpr,
     Vgpr
 };
...
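Because both workaround macros are introduced with the #ifndef / #define-to-1 pattern shown above, they default to "enabled" but can be switched off per build once the underlying compiler issue is fixed. A small stand-in program illustrating how such a flag is consumed and overridden (the flag name mirrors the config, the rest is illustrative; a hipcc/clang override would look like -DCK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE=0):

#include <cstdio>

#ifndef CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
#define CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE 1 // default: workaround enabled
#endif

int main()
{
#if CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
    std::printf("int8 ds_write workaround path\n"); // i8 vectors stored through i32 vectors
#else
    std::printf("direct int8 vector store path\n"); // selected with -D...=0 once fixed
#endif
    return 0;
}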
composable_kernel/include/utility/dynamic_buffer.hpp
0 → 100644
View file @
be49a8c5
#ifndef CK_DYNAMIC_BUFFER_HPP
#define CK_DYNAMIC_BUFFER_HPP

namespace ck {

#include "amd_buffer_addressing_v2.hpp"

template <AddressSpace BufferAddressSpace, typename T, typename ElementSpaceSize>
struct DynamicBuffer
{
    using type = T;

    T* p_data_;
    ElementSpaceSize element_space_size_;

    __host__ __device__ constexpr DynamicBuffer(T* p_data, ElementSpaceSize element_space_size)
        : p_data_{p_data}, element_space_size_{element_space_size}
    {
    }

    __host__ __device__ static constexpr AddressSpace GetAddressSpace()
    {
        return BufferAddressSpace;
    }

    __host__ __device__ constexpr const T& operator[](index_t i) const { return p_data_[i]; }

    __host__ __device__ constexpr T& operator()(index_t i) { return p_data_[i]; }

    template <typename X,
              typename std::enable_if<
                  is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
                          typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
                  bool>::type = false>
    __host__ __device__ constexpr const auto Get(index_t i, bool is_valid_offset) const
    {
        // X contains multiple T
        constexpr index_t scalar_per_t_vector =
            scalar_type<remove_cv_t<remove_reference_t<T>>>::vector_size;

        constexpr index_t scalar_per_x_vector =
            scalar_type<remove_cv_t<remove_reference_t<X>>>::vector_size;

        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                      "wrong! X need to be multiple T");

        constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

        if constexpr(GetAddressSpace() == AddressSpace::Global)
        {
#if CK_USE_AMD_BUFFER_ADDRESSING
            return amd_buffer_load_v2<remove_cv_t<remove_reference_t<T>>, t_per_x>(
                p_data_, i, is_valid_offset, element_space_size_);
#else
            return is_valid_offset ? *reinterpret_cast<const X*>(&p_data_[i]) : X{0};
#endif
        }
        else
        {
            return is_valid_offset ? *reinterpret_cast<const X*>(&p_data_[i]) : X{0};
        }
    }

    template <typename X,
              typename std::enable_if<
                  is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
                          typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
                  bool>::type = false>
    __host__ __device__ void Set(index_t i, bool is_valid_offset, const X& x)
    {
        // X contains multiple T
        constexpr index_t scalar_per_t_vector =
            scalar_type<remove_cv_t<remove_reference_t<T>>>::vector_size;

        constexpr index_t scalar_per_x_vector =
            scalar_type<remove_cv_t<remove_reference_t<X>>>::vector_size;

        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                      "wrong! X need to be multiple T");

        constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

        if constexpr(GetAddressSpace() == AddressSpace::Global)
        {
#if CK_USE_AMD_BUFFER_ADDRESSING
            amd_buffer_store_v2<remove_cv_t<remove_reference_t<T>>, t_per_x>(
                x, p_data_, i, is_valid_offset, element_space_size_);
#else
            if(is_valid_offset)
            {
                *reinterpret_cast<X*>(&p_data_[i]) = x;
            }
#endif
        }
        else if constexpr(GetAddressSpace() == AddressSpace::Lds)
        {
            if(is_valid_offset)
            {
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
                *reinterpret_cast<X*>(&p_data_[i]) = x;
#else
                // HACK: compiler would lower IR "store<i8, 16> address_space(3)" into inefficient
                // ISA, so I try to let compiler emit use IR "store<i32, 4>" which would be lower to
                // ds_write_b128
                // TODO: remove this after compiler fix
                if constexpr(is_same<typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type,
                                     int8_t>::value)
                {
                    static_assert(
                        (is_same<remove_cv_t<remove_reference_t<T>>, int8x4_t>::value &&
                         is_same<remove_cv_t<remove_reference_t<X>>, int8x4_t>::value) ||
                            (is_same<remove_cv_t<remove_reference_t<T>>, int8x8_t>::value &&
                             is_same<remove_cv_t<remove_reference_t<X>>, int8x8_t>::value) ||
                            (is_same<remove_cv_t<remove_reference_t<T>>, int8x16_t>::value &&
                             is_same<remove_cv_t<remove_reference_t<X>>, int8x16_t>::value),
                        "wrong! not implemented for this combination, please add implementation");

                    if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8x4_t>::value &&
                                 is_same<remove_cv_t<remove_reference_t<X>>, int8x4_t>::value)
                    {
                        // HACK: cast pointer of x is bad
                        // TODO: remove this after compiler fix
                        *reinterpret_cast<int32_t*>(&p_data_[i]) =
                            *reinterpret_cast<const int32_t*>(&x);
                    }

                    if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8x8_t>::value &&
                                 is_same<remove_cv_t<remove_reference_t<X>>, int8x8_t>::value)
                    {
                        // HACK: cast pointer of x is bad
                        // TODO: remove this after compiler fix
                        *reinterpret_cast<int32x2_t*>(&p_data_[i]) =
                            *reinterpret_cast<const int32x2_t*>(&x);
                    }

                    if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8x16_t>::value &&
                                 is_same<remove_cv_t<remove_reference_t<X>>, int8x16_t>::value)
                    {
                        // HACK: cast pointer of x is bad
                        // TODO: remove this after compiler fix
                        *reinterpret_cast<int32x4_t*>(&p_data_[i]) =
                            *reinterpret_cast<const int32x4_t*>(&x);
                    }
                }
                else
                {
                    *reinterpret_cast<X*>(&p_data_[i]) = x;
                }
#endif
            }
        }
        else
        {
            if(is_valid_offset)
            {
                *reinterpret_cast<X*>(&p_data_[i]) = x;
            }
        }
    }

    __host__ __device__ static constexpr bool IsStaticBuffer() { return false; }

    __host__ __device__ static constexpr bool IsDynamicBuffer() { return true; }
};

template <AddressSpace BufferAddressSpace = AddressSpace::Generic,
          typename T,
          typename ElementSpaceSize>
__host__ __device__ constexpr auto make_dynamic_buffer(T* p, ElementSpaceSize element_space_size)
{
    return DynamicBuffer<BufferAddressSpace, T, ElementSpaceSize>{p, element_space_size};
}

} // namespace ck
#endif
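The Get/Set interface above takes an explicit is_valid_offset predicate: on the generic (non-buffer-intrinsic) path an invalid read returns a zero-initialized X and an invalid write is dropped, which is what lets callers fold padding and out-of-bounds checks into a single bool. The snippet below is a standalone approximation of just that behavior, using memcpy for the reinterpretation; it is not the CK header itself.

#include <cstddef>
#include <cstring>

template <typename T>
struct SimpleDynBuf
{
    T* p_data_;
    std::size_t element_space_size_;

    template <typename X>
    X Get(std::size_t i, bool is_valid_offset) const
    {
        X x{};
        if(is_valid_offset)
            std::memcpy(&x, &p_data_[i], sizeof(X)); // valid: reinterpret T storage as X
        return x;                                    // invalid: zero-initialized X
    }

    template <typename X>
    void Set(std::size_t i, bool is_valid_offset, const X& x)
    {
        if(is_valid_offset)
            std::memcpy(&p_data_[i], &x, sizeof(X)); // invalid writes are simply skipped
    }
};

// Usage: the predicate is computed from the coordinate at the call site,
// e.g. "does this offset land inside the padded region?"
// float data[8] = {};
// SimpleDynBuf<float> buf{data, 8};
// float v = buf.Get<float>(3, /*is_valid_offset=*/true);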
composable_kernel/include/utility/static_buffer.hpp
0 → 100644
View file @
be49a8c5
#ifndef CK_STATIC_BUFFER_HPP
#define CK_STATIC_BUFFER_HPP

#include "statically_indexed_array.hpp"

namespace ck {

template <AddressSpace BufferAddressSpace, typename T, index_t N>
struct StaticBuffer : public StaticallyIndexedArray<T, N>
{
    using type = T;
    using base = StaticallyIndexedArray<T, N>;

    __host__ __device__ constexpr StaticBuffer() : base{} {}

    __host__ __device__ static constexpr AddressSpace GetAddressSpace()
    {
        return BufferAddressSpace;
    }

    __host__ __device__ static constexpr bool IsStaticBuffer() { return true; }

    __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
};

template <AddressSpace BufferAddressSpace = AddressSpace::Generic, typename T, index_t N>
__host__ __device__ constexpr auto make_static_buffer(Number<N>)
{
    return StaticBuffer<BufferAddressSpace, T, N>{};
}

} // namespace ck
#endif
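StaticBuffer is the register-resident counterpart of DynamicBuffer: a fixed-size array built on StaticallyIndexedArray whose length is a compile-time constant, used for the c_thread_buf and b_thread_*_buf accumulators in the GEMM hunks above. A standalone stand-in using std::array, with illustrative names only, is sketched below to show the shape of that idea.

#include <array>
#include <cstddef>

template <typename T, std::size_t N>
struct MiniStaticBuffer : std::array<T, N>
{
    static constexpr bool IsStaticBuffer() { return true; }
    static constexpr bool IsDynamicBuffer() { return false; }
};

template <typename T, std::size_t N>
constexpr MiniStaticBuffer<T, N> make_mini_static_buffer()
{
    return MiniStaticBuffer<T, N>{}; // value-initialized accumulator storage
}

// Usage mirroring the gridwise GEMM: one accumulator per element of the
// per-thread output tile, zero-initialized before the K loop.
constexpr std::size_t thread_tile_size = 16;
auto c_thread_buf_example = make_mini_static_buffer<float, thread_tile_size>();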
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
View file @
be49a8c5
 #include <unistd.h>
 #include "device.hpp"
 #include "host_tensor.hpp"
-#include "driver_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
+#include "transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp"
+#include "driver_dynamic_gemm_v1.hpp"

 template <class TInWei,
           ck::index_t InWeiVectorSize,
...
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp
View file @
be49a8c5
 #include <unistd.h>
 #include "device.hpp"
 #include "host_tensor.hpp"
-#include "driver_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp"
+#include "transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp"
+#include "driver_dynamic_gemm_v1.hpp"

 template <class TInWei,
           ck::index_t InWeiVectorSize,
...
driver/include/device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp
View file @
be49a8c5
...
@@ -63,12 +63,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
 #if 0
     // run-time variables
-    const auto in_n_c_hi_wi_desc =
-        make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(InDesc::GetLengths()));
-    const auto wei_k_c_y_x_desc =
-        make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(WeiDesc::GetLengths()));
-    const auto out_n_k_ho_wo_desc =
-        make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(OutDesc::GetLengths()));
+    const auto in_n_c0_hi_wi_desc =
+        make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(N, C0, Hi, Wi));
+    const auto wei_k_c0_y_x_desc =
+        make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(K, C0, Y, X));
+    const auto out_n_k0_ho_wo_k1_desc =
+        make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(N, K0, Ho, Wo, K1));

     const auto conv_strides = to_multi_index(ConvStrides{});
     const auto conv_dilations = to_multi_index(ConvDilations{});
...
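A "packed" naive descriptor, as constructed above from make_multi_index(N, C0, Hi, Wi) and friends, is fully determined by its lengths: the strides follow from the lengths right-to-left, with the innermost dimension contiguous. The snippet below is a standalone illustration of that stride/offset rule and is not the CK descriptor class itself.

#include <array>
#include <cstddef>

template <std::size_t Rank>
std::array<std::size_t, Rank> packed_strides(const std::array<std::size_t, Rank>& lengths)
{
    std::array<std::size_t, Rank> strides{};
    std::size_t s = 1;
    for(std::size_t d = Rank; d-- > 0;)
    {
        strides[d] = s; // innermost dimension gets stride 1
        s *= lengths[d];
    }
    return strides;
}

template <std::size_t Rank>
std::size_t packed_offset(const std::array<std::size_t, Rank>& idx,
                          const std::array<std::size_t, Rank>& lengths)
{
    const auto strides = packed_strides(lengths);
    std::size_t off = 0;
    for(std::size_t d = 0; d < Rank; ++d)
        off += idx[d] * strides[d];
    return off;
}

// e.g. an N x C0 x Hi x Wi input tensor:
// packed_offset<4>({n, c0, hi, wi}, {N, C0, Hi, Wi})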