gaoqiong / composable_kernel / Commits / be49a8c5

Commit be49a8c5, authored May 12, 2021 by Jing Zhang
Commit message: merge master
Parents: bcdc330d, 71d6b19d

Showing 20 changed files with 877 additions and 629 deletions (+877 −629)
composable_kernel/include/driver/driver_dynamic_gemm_v1.hpp                                                  +65  −73
composable_kernel/include/driver/driver_dynamic_gemm_xdlops_v1.hpp                                           +185 −64
composable_kernel/include/kernel_algorithm/transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp   +2   −3
composable_kernel/include/kernel_algorithm/transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp   +2   −3
composable_kernel/include/tensor_operation/blockwise_dynamic_tensor_slice_transfer.hpp                       +6   −9
composable_kernel/include/tensor_operation/blockwise_gemm_v2.hpp                                             +10  −12
composable_kernel/include/tensor_operation/blockwise_gemm_v3.hpp                                             +1   −3
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp                                         +103 −96
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v2.hpp                                      +20  −19
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops.hpp                                  +130 −76
composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp                      +121 −182
composable_kernel/include/utility/amd_buffer_addressing_v2.hpp                                               +6   −6
composable_kernel/include/utility/buffer.hpp                                                                 +0   −72
composable_kernel/include/utility/common_header.hpp                                                          +2   −1
composable_kernel/include/utility/config.amd.hpp.in                                                          +8   −2
composable_kernel/include/utility/dynamic_buffer.hpp                                                         +173 −0
composable_kernel/include/utility/static_buffer.hpp                                                          +33  −0
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp                      +2   −1
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp                      +2   −1
driver/include/device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp                      +6   −6
composable_kernel/include/driver/driver_dynamic_gemm_v1.hpp

@@ -146,16 +146,16 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
     if(has_main_k_block_loop && has_double_tail_k_block_loop)
     {
-        const auto kernel = run_gridwise_operation<gridwise_gemm,
-                                                   remove_reference_t<AGlobalDesc>,
-                                                   const FloatAB*,
-                                                   remove_reference_t<BGlobalDesc>,
-                                                   const FloatAB*,
-                                                   remove_reference_t<CGlobalDesc>,
-                                                   FloatC*,
-                                                   remove_reference_t<CBlockClusterDesc>,
-                                                   integral_constant<bool, true>,
-                                                   integral_constant<bool, true>>;
+        const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
+                                                   FloatAB,
+                                                   FloatAB,
+                                                   FloatC,
+                                                   remove_reference_t<AGlobalDesc>,
+                                                   remove_reference_t<BGlobalDesc>,
+                                                   remove_reference_t<CGlobalDesc>,
+                                                   remove_reference_t<CBlockClusterDesc>,
+                                                   true,
+                                                   true>;

         ave_time = launch_and_time_kernel(kernel,
                                           nrepeat,
@@ -163,28 +163,26 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
                                           dim3(BlockSize),
                                           0,
                                           0,
-                                          a_k_m_global_desc,
-                                          p_a_global,
-                                          b_k_n_global_desc,
-                                          p_b_global,
-                                          c_m0_m1_n0_n1_global_desc,
-                                          p_c_global,
-                                          c_block_cluster_desc,
-                                          integral_constant<bool, true>{},
-                                          integral_constant<bool, true>{});
+                                          p_a_global,
+                                          p_b_global,
+                                          p_c_global,
+                                          a_k_m_global_desc,
+                                          b_k_n_global_desc,
+                                          c_m0_m1_n0_n1_global_desc,
+                                          c_block_cluster_desc);
     }
     else if(has_main_k_block_loop && !has_double_tail_k_block_loop)          // @@ -192,28 +190,26 @@
     {
         // same rewrite with trailing template arguments <..., true, false> and the
         // launch arguments reordered to (p_a_global, p_b_global, p_c_global, descriptors...)
     }
     else if(!has_main_k_block_loop && has_double_tail_k_block_loop)          // @@ -221,28 +217,26 @@
     {
         // same rewrite with trailing template arguments <..., false, true>
     }
     else                                                                     // @@ -250,15 +244,13 @@
     {
         // same rewrite with trailing template arguments <..., false, false>
     }

     return ave_time;

@@ -277,13 +269,13 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
     if(has_main_k_block_loop && has_double_tail_k_block_loop)
     {
-        const auto kernel = run_gridwise_operation<gridwise_gemm,
-                                                   remove_reference_t<AGlobalDesc>,
-                                                   const FloatAB*,
-                                                   remove_reference_t<BGlobalDesc>,
-                                                   const FloatAB*,
-                                                   remove_reference_t<CGlobalDesc>,
-                                                   FloatC*,
-                                                   remove_reference_t<CBlockClusterDesc>,
+        const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
+                                                   FloatAB,
+                                                   FloatAB,
+                                                   FloatC,
+                                                   remove_reference_t<AGlobalDesc>,
+                                                   remove_reference_t<BGlobalDesc>,
+                                                   remove_reference_t<CGlobalDesc>,
+                                                   remove_reference_t<CBlockClusterDesc>,
                                                    true,
                                                    true>;
@@ -295,23 +287,23 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
                                           dim3(BlockSize),
                                           0,
                                           0,
-            (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
-            p_a_global,
-            (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
-            p_b_global,
-            (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
-            p_c_global,
+            p_a_global,
+            p_b_global,
+            p_c_global,
+            (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
+            (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
+            (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
             (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
     }
     else if(has_main_k_block_loop && !has_double_tail_k_block_loop)          // @@ -323,23 +315,23 @@
     {
-        const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
-                                                   remove_reference_t<AGlobalDesc>, FloatAB,
-                                                   remove_reference_t<BGlobalDesc>, FloatAB,
-                                                   remove_reference_t<CGlobalDesc>, FloatC,
-                                                   remove_reference_t<CBlockClusterDesc>,
-                                                   true, false>;
+        const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
+                                                   FloatAB, FloatAB, FloatC,
+                                                   remove_reference_t<AGlobalDesc>,
+                                                   remove_reference_t<BGlobalDesc>,
+                                                   remove_reference_t<CGlobalDesc>,
+                                                   remove_reference_t<CBlockClusterDesc>,
+                                                   true, false>;
         // launch arguments reordered exactly as in the branch above
     }
     else if(!has_main_k_block_loop && has_double_tail_k_block_loop)          // @@ -351,23 +343,23 @@
     {
         // same reordering with trailing template arguments <..., false, true>
     }
     else                                                                     // @@ -379,12 +371,12 @@
     {
         // same reordering with trailing template arguments <..., false, false>
     }
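Taken together, the hunks above change every dispatch branch of launch_kernel_dynamic_gemm_v1 in the same way. As a consolidated, hedged sketch of one branch after the change (all identifiers are taken from the diff itself; GridSize, nrepeat and the descriptor variables are the launcher's existing locals), the call site now reads roughly:

    // Sketch of one post-change dispatch branch: the loop predicates become plain
    // bool template arguments, and the kernel takes data pointers first and
    // tensor descriptors last.
    if(has_main_k_block_loop && has_double_tail_k_block_loop)
    {
        const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
                                                   FloatAB, FloatAB, FloatC,
                                                   remove_reference_t<AGlobalDesc>,
                                                   remove_reference_t<BGlobalDesc>,
                                                   remove_reference_t<CGlobalDesc>,
                                                   remove_reference_t<CBlockClusterDesc>,
                                                   true,   // HasMainKBlockLoop
                                                   true>;  // HasDoubleTailKBlockLoop

        ave_time = launch_and_time_kernel(kernel,
                                          nrepeat,
                                          dim3(GridSize),
                                          dim3(BlockSize),
                                          0,
                                          0,
                                          p_a_global,
                                          p_b_global,
                                          p_c_global,
                                          a_k_m_global_desc,
                                          b_k_n_global_desc,
                                          c_m0_m1_n0_n1_global_desc,
                                          c_block_cluster_desc);
    }

The four if/else branches exist because the two loop flags are runtime values that must be mapped onto compile-time template arguments, so each flag combination gets its own kernel instantiation.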
composable_kernel/include/driver/driver_dynamic_gemm_xdlops_v1.hpp

@@ -141,20 +141,21 @@ __host__ float launch_kernel_dynamic_gemm_xdlops_v1(const FloatAB* p_a_global,
     const bool has_double_tail_k_block_loop = (K / KPerBlock) % 2 == 0;

+#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
     float ave_time = 0;

     if(has_main_k_block_loop && has_double_tail_k_block_loop)
     {
-        const auto kernel = run_gridwise_operation<gridwise_gemm,
-                                                   remove_reference_t<AGlobalDesc>, const FloatAB*,
-                                                   remove_reference_t<BGlobalDesc>, const FloatAB*,
-                                                   remove_reference_t<CGlobalDesc>, FloatC*,
-                                                   remove_reference_t<CBlockClusterDesc>,
-                                                   integral_constant<bool, true>,
-                                                   integral_constant<bool, true>>;
+        const auto kernel = kernel_dynamic_gemm_xdlops_v1<gridwise_gemm,
+                                                          FloatAB, FloatAB, FloatC,
+                                                          remove_reference_t<AGlobalDesc>,
+                                                          remove_reference_t<BGlobalDesc>,
+                                                          remove_reference_t<CGlobalDesc>,
+                                                          remove_reference_t<CBlockClusterDesc>,
+                                                          true, true>;

         ave_time = launch_and_time_kernel(kernel,
                                           nrepeat,
@@ -162,28 +163,26 @@ __host__ float launch_kernel_dynamic_gemm_xdlops_v1(const FloatAB* p_a_global,
                                           dim3(BlockSize),
                                           0,
                                           0,
-                                          a_k_m_global_desc, p_a_global,
-                                          b_k_n_global_desc, p_b_global,
-                                          c_m0_m1_n0_n1_global_desc, p_c_global,
-                                          c_block_cluster_desc,
-                                          integral_constant<bool, true>{},
-                                          integral_constant<bool, true>{});
+                                          p_a_global, p_b_global, p_c_global,
+                                          a_k_m_global_desc, b_k_n_global_desc,
+                                          c_m0_m1_n0_n1_global_desc, c_block_cluster_desc);
     }
     else if(has_main_k_block_loop && !has_double_tail_k_block_loop)          // @@ -191,28 +190,26 @@
     {
         // same rewrite with trailing template arguments <..., true, false>
     }
     else if(!has_main_k_block_loop && has_double_tail_k_block_loop)          // @@ -220,28 +217,26 @@
     {
         // same rewrite with trailing template arguments <..., false, true>
     }
     else                                                                     // @@ -249,18 +244,144 @@
     {
         // same rewrite with trailing template arguments <..., false, false>
     }

     return ave_time;

+#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
+    DeviceMem a_k_m_global_desc_device_buf(sizeof(AGlobalDesc));
+    DeviceMem b_k_n_global_desc_device_buf(sizeof(BGlobalDesc));
+    DeviceMem c_m0_m1_n0_n1_global_desc_device_buf(sizeof(CGlobalDesc));
+    DeviceMem c_block_cluster_desc_device_buf(sizeof(c_block_cluster_desc));
+
+    a_k_m_global_desc_device_buf.ToDevice(&a_k_m_global_desc);
+    b_k_n_global_desc_device_buf.ToDevice(&b_k_n_global_desc);
+    c_m0_m1_n0_n1_global_desc_device_buf.ToDevice(&c_m0_m1_n0_n1_global_desc);
+    c_block_cluster_desc_device_buf.ToDevice(&c_block_cluster_desc);
+
+    float ave_time = 0;
+
+    if(has_main_k_block_loop && has_double_tail_k_block_loop)
+    {
+        const auto kernel = kernel_dynamic_gemm_xdlops_v1<gridwise_gemm,
+                                                          FloatAB, FloatAB, FloatC,
+                                                          remove_reference_t<AGlobalDesc>,
+                                                          remove_reference_t<BGlobalDesc>,
+                                                          remove_reference_t<CGlobalDesc>,
+                                                          remove_reference_t<CBlockClusterDesc>,
+                                                          true, true>;
+
+        ave_time = launch_and_time_kernel(
+            kernel, nrepeat, dim3(GridSize), dim3(BlockSize), 0, 0,
+            p_a_global, p_b_global, p_c_global,
+            (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
+            (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
+            (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
+            (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
+    }
+    else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
+    {
+        // same as above with trailing template arguments <..., true, false>
+    }
+    else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
+    {
+        // same as above with trailing template arguments <..., false, true>
+    }
+    else
+    {
+        // same as above with trailing template arguments <..., false, false>
+    }
+
+    return ave_time;
+#endif
 }

 } // namespace ck
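The newly added CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER branch stages each tensor descriptor in device memory once and hands the kernel a __CONSTANT__ pointer instead of passing the descriptor by value. A minimal host-side sketch of that pattern, using only names that appear in the hunks above, would look like this:

    // Host side (sketch): copy one descriptor to device memory, then pass its
    // address to the kernel as a __CONSTANT__ void pointer.
    DeviceMem a_k_m_global_desc_device_buf(sizeof(AGlobalDesc));
    a_k_m_global_desc_device_buf.ToDevice(&a_k_m_global_desc); // host -> device copy

    // ... the B, C and block-cluster descriptors are staged the same way ...

    ave_time = launch_and_time_kernel(
        kernel, nrepeat, dim3(GridSize), dim3(BlockSize), 0, 0,
        p_a_global, p_b_global, p_c_global,
        (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
        (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
        (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
        (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());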
composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
→ composable_kernel/include/kernel_algorithm/transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp

-#ifndef CK_DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_HPP
-#define CK_DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_HPP
+#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4_NCHW_KCYX_NKHW_HPP
+#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4_NCHW_KCYX_NKHW_HPP

 #include "common_header.hpp"
 #include "dynamic_tensor_descriptor.hpp"
 #include "dynamic_tensor_descriptor_helper.hpp"
-#include "driver_dynamic_gemm_v1.hpp"

 namespace ck {
composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp
→ composable_kernel/include/kernel_algorithm/transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp

-#ifndef CK_DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R4_NHWC_KYXC_NHWK_HPP
-#define CK_DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R4_NHWC_KYXC_NHWK_HPP
+#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4_NHWC_KYXC_NHWK_HPP
+#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4_NHWC_KYXC_NHWK_HPP

 #include "common_header.hpp"
 #include "dynamic_tensor_descriptor.hpp"
 #include "dynamic_tensor_descriptor_helper.hpp"
-#include "driver_dynamic_gemm_v1.hpp"

 namespace ck {
composable_kernel/include/tensor_operation/blockwise_dynamic_tensor_slice_transfer.hpp

@@ -29,8 +29,6 @@ template <index_t BlockSize,
           index_t DstVectorDim,
           index_t SrcScalarPerVector,
           index_t DstScalarPerVector,
-          AddressSpace SrcAddressSpace,
-          AddressSpace DstAddressSpace,
           index_t SrcScalarStrideInVector,
           index_t DstScalarStrideInVector,
           index_t ThreadTransferSrcResetCoordinateAfterRun,
@@ -79,24 +77,25 @@ struct BlockwiseDynamicTensorSliceTransfer_v4
         }
     }

-    template <typename SrcIteratorHacks>
+    template <typename SrcBuffer, typename SrcIteratorHacks>
     __device__ void RunRead(const SrcDesc& src_desc,
-                            const SrcData* p_src,
+                            const SrcBuffer& src_buf,
                             const SrcIteratorHacks& src_iterator_hacks)
     {
         if(BlockSize == thread_cluster_desc_.GetElementSize() or
            get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
         {
-            threadwise_transfer_.RunRead(src_desc, p_src, src_iterator_hacks);
+            threadwise_transfer_.RunRead(src_desc, src_buf, src_iterator_hacks);
         }
     }

-    __device__ void RunWrite(const DstDesc& dst_desc, DstData* p_dst)
+    template <typename DstBuffer>
+    __device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf)
     {
         if(BlockSize == thread_cluster_desc_.GetElementSize() or
            get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
         {
-            threadwise_transfer_.RunWrite(dst_desc, p_dst);
+            threadwise_transfer_.RunWrite(dst_desc, dst_buf);
         }
     }
@@ -152,8 +151,6 @@ struct BlockwiseDynamicTensorSliceTransfer_v4
         DstScalarPerVector,
         SrcScalarStrideInVector,
         DstScalarStrideInVector,
-        SrcAddressSpace,
-        DstAddressSpace,
         ThreadTransferSrcResetCoordinateAfterRun,
         ThreadTransferDstResetCoordinateAfterRun>;
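With this change the blockwise copy no longer sees raw pointers: the source and destination are buffer objects, and the address space that used to be a template parameter now travels with the buffer type. A short usage sketch (the buffer and descriptor names are the ones used by gridwise_dynamic_gemm.hpp later in this commit):

    // Read a tile from a global-memory buffer, then write it into an LDS buffer;
    // the address spaces come from the buffer types, not from extra template arguments.
    a_blockwise_copy.RunRead(a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks);
    a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_even_buf);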
composable_kernel/include/tensor_operation/blockwise_gemm_v2.hpp

@@ -115,8 +115,10 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1
                            const BBlockBuffer& b_block_buf,
                            CThreadBuffer& c_thread_buf) const
     {
-        auto a_thread_buf = make_static_buffer<FloatA>(a_thread_desc_.GetElementSpaceSize());
-        auto b_thread_buf = make_static_buffer<FloatB>(b_thread_desc_.GetElementSpaceSize());
+        auto a_thread_buf =
+            make_static_buffer<AddressSpace::Vgpr, FloatA>(a_thread_desc_.GetElementSpaceSize());
+        auto b_thread_buf =
+            make_static_buffer<AddressSpace::Vgpr, FloatB>(b_thread_desc_.GetElementSpaceSize());

         constexpr auto threadwise_gemm =
             ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1<FloatA,
@@ -176,8 +178,6 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1
                                             Sequence<0, 1, 2>,
                                             2,
                                             AThreadCopyScalarPerVector_M1,
-                                            AddressSpace::Generic,
-                                            AddressSpace::Vgpr,
                                             1>;

     using BThreadCopy =
@@ -189,8 +189,6 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1
                                             Sequence<0, 1, 2>,
                                             2,
                                             BThreadCopyScalarPerVector_N1,
-                                            AddressSpace::Generic,
-                                            AddressSpace::Vgpr,
                                             1>;

     CIndex c_thread_origin_data_idx_;
@@ -211,6 +209,8 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1
 // 3. C:
 //   1. CThreadDesc is known at compile-time
 //   2. CThreadBuffer is StaticBuffer
+// Also assume:
+//   M0 = N0 = 2. It will do 2x2 pipelined read and fma (ABBA optimization)
 template <index_t BlockSize,
           typename FloatA,
           typename FloatB,
@@ -312,8 +312,10 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2_pipeline_2x2
     (same make_static_buffer<AddressSpace::Vgpr, ...> change as in the v1 struct above)
@@ -481,8 +483,6 @@ and @@ -494,8 +494,6 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2_pipeline_2x2
     (same removal of AddressSpace::Generic / AddressSpace::Vgpr from AThreadCopy and BThreadCopy)
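The make_static_buffer change above is easier to read assembled: the address space is now the leading template argument, so register-resident thread buffers are declared explicitly as VGPR buffers. Assembled form of the new code, with no names beyond those in the hunks:

    // Thread-private buffers now name their address space explicitly.
    auto a_thread_buf =
        make_static_buffer<AddressSpace::Vgpr, FloatA>(a_thread_desc_.GetElementSpaceSize());
    auto b_thread_buf =
        make_static_buffer<AddressSpace::Vgpr, FloatB>(b_thread_desc_.GetElementSpaceSize());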
composable_kernel/include/tensor_operation/blockwise_gemm_v3.hpp

@@ -49,8 +49,6 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
                                             Sequence<0, 1>,
                                             1,
                                             ThreadGemmADataPerRead_K,
-                                            AddressSpace::Generic,
-                                            AddressSpace::Vgpr,
                                             1>;

     __device__ BlockwiseGemm_km_kn_m0m1n0n1_v3()
@@ -140,7 +138,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
         static_assert(WPerThread % WoPerThreadSubC == 0, "");

         // thread A buffer for GEMM
-        StaticBuffer<FloatA, a_thread_mtx_.GetElementSpaceSize()> a_thread_buf;
+        StaticBuffer<AddressSpace::Vgpr, FloatA, a_thread_mtx_.GetElementSpaceSize()> a_thread_buf;

         constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v3<FloatA,
                                                                     FloatB,
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp

@@ -14,54 +14,62 @@ namespace ck {
 #if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
 template <typename GridwiseGemm,
-          typename AGlobalDesc,
           typename FloatA,
-          typename BGlobalDesc,
           typename FloatB,
-          typename CGlobalDesc,
           typename FloatC,
+          typename AGlobalDesc,
+          typename BGlobalDesc,
+          typename CGlobalDesc,
           typename CBlockClusterDesc,
           bool HasMainKBlockLoop,
           bool HasDoubleTailKBlockLoop>
-__global__ void kernel_dynamic_gemm_v1(const AGlobalDesc a_k_m_global_desc,
-                                       const FloatA* __restrict__ p_a_global,
-                                       const BGlobalDesc b_k_n_global_desc,
-                                       const FloatB* __restrict__ p_b_global,
-                                       const CGlobalDesc c_m0_m1_n0_n1_global_desc,
-                                       FloatC* __restrict__ p_c_global,
-                                       const CBlockClusterDesc c_block_cluster_desc)
+__global__ void
+#if CK_USE_LAUNCH_BOUNDS
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+#endif
+    kernel_dynamic_gemm_v1(const FloatA* __restrict__ p_a_global,
+                           const FloatB* __restrict__ p_b_global,
+                           FloatC* __restrict__ p_c_global,
+                           const AGlobalDesc a_k_m_global_desc,
+                           const BGlobalDesc b_k_n_global_desc,
+                           const CGlobalDesc c_m0_m1_n0_n1_global_desc,
+                           const CBlockClusterDesc c_block_cluster_desc)
 {
-    GridwiseGemm{}.Run(a_k_m_global_desc,
-                       p_a_global,
-                       b_k_n_global_desc,
-                       p_b_global,
-                       c_m0_m1_n0_n1_global_desc,
-                       p_c_global,
-                       c_block_cluster_desc,
-                       integral_constant<bool, HasMainKBlockLoop>{},
-                       integral_constant<bool, HasDoubleTailKBlockLoop>{});
+    GridwiseGemm::Run(p_a_global,
+                      p_b_global,
+                      p_c_global,
+                      a_k_m_global_desc,
+                      b_k_n_global_desc,
+                      c_m0_m1_n0_n1_global_desc,
+                      c_block_cluster_desc,
+                      integral_constant<bool, HasMainKBlockLoop>{},
+                      integral_constant<bool, HasDoubleTailKBlockLoop>{});
 }
 #elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
 // pass tensor descriptor by __CONSTANT__ void pointer
 // __CONSTANT__ is needed to inform compiler void pointers in the kernel signature are pointing to
 // non-modifiable parameter address space, so compiler can enable corresponding optimization
 template <typename GridwiseGemm,
           (same reordering of the template parameters as above: FloatA, FloatB, FloatC first,
            then AGlobalDesc, BGlobalDesc, CGlobalDesc, CBlockClusterDesc, then the two bools)
-__global__ void kernel_dynamic_gemm_v1(const void __CONSTANT__* p_a_k_m_global_desc,
-                                       const FloatA* __restrict__ p_a_global,
-                                       const void __CONSTANT__* p_b_k_n_global_desc,
-                                       const FloatB* __restrict__ p_b_global,
-                                       const void __CONSTANT__* p_c_m0_m1_n0_n1_global_desc,
-                                       FloatC* __restrict__ p_c_global,
-                                       const void __CONSTANT__* p_c_block_cluster_desc)
+__global__ void
+#if CK_USE_LAUNCH_BOUNDS
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+#endif
+    kernel_dynamic_gemm_v1(const FloatA* __restrict__ p_a_global,
+                           const FloatB* __restrict__ p_b_global,
+                           FloatC* __restrict__ p_c_global,
+                           const void __CONSTANT__* p_a_k_m_global_desc,
+                           const void __CONSTANT__* p_b_k_n_global_desc,
+                           const void __CONSTANT__* p_c_m0_m1_n0_n1_global_desc,
+                           const void __CONSTANT__* p_c_block_cluster_desc)
 {
     // first cast void __CONSTANT__ void* to void*
     // second cast void* to Desc*
@@ -76,15 +84,15 @@ __global__ void kernel_dynamic_gemm_v1(const void __CONSTANT__* p_a_k_m_global_d
     const auto c_block_cluster_desc =
         *reinterpret_cast<const CBlockClusterDesc*>((const void*)p_c_block_cluster_desc);

-    GridwiseGemm{}.Run(a_k_m_global_desc,
-                       p_a_global,
-                       b_k_n_global_desc,
-                       p_b_global,
-                       c_m0_m1_n0_n1_global_desc,
-                       p_c_global,
-                       c_block_cluster_desc,
+    GridwiseGemm::Run(p_a_global,
+                      p_b_global,
+                      p_c_global,
+                      a_k_m_global_desc,
+                      b_k_n_global_desc,
+                      c_m0_m1_n0_n1_global_desc,
+                      c_block_cluster_desc,
                       integral_constant<bool, HasMainKBlockLoop>{},
                       integral_constant<bool, HasDoubleTailKBlockLoop>{});
 }
 #endif
@@ -161,22 +169,29 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
     }

     template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
-    __device__ void Run(const AGlobalDesc& a_k_m_global_desc,
-                        const FloatAB* __restrict__ p_a_global,
-                        const BGlobalDesc& b_k_n_global_desc,
-                        const FloatAB* __restrict__ p_b_global,
-                        const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
-                        FloatC* __restrict__ p_c_global,
-                        const CBlockClusterDesc& c_block_cluster_desc,
-                        FloatAB* __restrict__ p_shared_block,
-                        integral_constant<bool, HasMainKBlockLoop>,
-                        integral_constant<bool, HasDoubleTailKBlockLoop>) const
+    __device__ static void Run(const FloatAB* __restrict__ p_a_global,
+                               const FloatAB* __restrict__ p_b_global,
+                               FloatC* __restrict__ p_c_global,
+                               const AGlobalDesc& a_k_m_global_desc,
+                               const BGlobalDesc& b_k_n_global_desc,
+                               const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
+                               const CBlockClusterDesc& c_block_cluster_desc,
+                               FloatAB* __restrict__ p_shared_block,
+                               integral_constant<bool, HasMainKBlockLoop>,
+                               integral_constant<bool, HasDoubleTailKBlockLoop>)
     {
         constexpr auto I0 = Number<0>{};
         constexpr auto I1 = Number<1>{};
         constexpr auto I2 = Number<2>{};
         constexpr auto I3 = Number<3>{};

+        const auto a_global_buf = make_dynamic_buffer<AddressSpace::Global>(
+            p_a_global, a_k_m_global_desc.GetElementSpaceSize());
+        const auto b_global_buf = make_dynamic_buffer<AddressSpace::Global>(
+            p_b_global, b_k_n_global_desc.GetElementSpaceSize());
+        auto c_global_buf = make_dynamic_buffer<AddressSpace::Global>(
+            p_c_global, c_m0_m1_n0_n1_global_desc.GetElementSpaceSize());
+
         const auto K = a_k_m_global_desc.GetLength(I0);
         const auto M = a_k_m_global_desc.GetLength(I1);
         const auto N = b_k_n_global_desc.GetLength(I1);
@@ -226,8 +241,6 @@ and @@ -255,8 +268,6 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
     (the A and B blockwise-copy template arguments drop the explicit AddressSpace::Global /
      AddressSpace::Lds entries; the address spaces are now carried by the buffer arguments)
@@ -331,8 +342,8 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
         FloatAB* p_b_block_double = p_shared_block + 2 * a_block_space_size;

         // register allocation for output
-        auto c_thread_buf =
-            make_static_buffer<FloatAcc>(c_m0_m1_n0_n1_thread_desc.GetElementSpaceSize());
+        auto c_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatAcc>(
+            c_m0_m1_n0_n1_thread_desc.GetElementSpaceSize());

         ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
                                            decltype(c_m0_m1_n0_n1_thread_desc),
@@ -353,25 +364,23 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
         constexpr auto b_k_n_global_move_slice_window_iterator_hack =
             BGlobalMoveSliceWindowIteratorHacks{};

-        FloatAB* p_a_block_even = p_a_block_double;
-        FloatAB* p_b_block_even = p_b_block_double;
-        FloatAB* p_a_block_odd  = p_a_block_double + a_block_space_size;
-        FloatAB* p_b_block_odd  = p_b_block_double + b_block_space_size;
-
-        auto a_block_even_buf = make_dynamic_buffer(p_a_block_even);
-        auto b_block_even_buf = make_dynamic_buffer(p_b_block_even);
-        auto a_block_odd_buf  = make_dynamic_buffer(p_a_block_odd);
-        auto b_block_odd_buf  = make_dynamic_buffer(p_b_block_odd);
+        auto a_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(
+            p_a_block_double, a_k_m_block_desc.GetElementSpaceSize());
+        auto b_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(
+            p_b_block_double, b_k_n_block_desc.GetElementSpaceSize());
+        auto a_block_odd_buf = make_dynamic_buffer<AddressSpace::Lds>(
+            p_a_block_double + a_block_space_size, a_k_m_block_desc.GetElementSpaceSize());
+        auto b_block_odd_buf = make_dynamic_buffer<AddressSpace::Lds>(
+            p_b_block_double + b_block_space_size, b_k_n_block_desc.GetElementSpaceSize());

         // LDS double buffer: preload data into LDS
         {
-            a_blockwise_copy.RunRead(a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks);
-            b_blockwise_copy.RunRead(b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
-            a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_double);
-            b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_double);
+            a_blockwise_copy.RunRead(a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks);
+            b_blockwise_copy.RunRead(b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);
+            a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_even_buf);
+            b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_even_buf);
         }

         if constexpr(HasMainKBlockLoop)
@@ -394,16 +403,16 @@ and @@ -417,16 +426,16 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
     (inside the main loop the even and odd iterations change the same way:
      RunRead now takes a_global_buf / b_global_buf and RunWrite now takes the
      a/b block odd/even buffers, while blockwise_gemm.Run(a_block_*_buf, b_block_*_buf,
      c_thread_buf) is unchanged)
@@ -445,15 +454,15 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
         __syncthreads();

         // LDS double buffer: load last data from device mem
-        a_blockwise_copy.RunRead(a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks);
-        b_blockwise_copy.RunRead(b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
+        a_blockwise_copy.RunRead(a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks);
+        b_blockwise_copy.RunRead(b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);

         // LDS double buffer: GEMM on 2nd-last data
         blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);

         // LDS double buffer: store last data to LDS
-        a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_double + a_block_space_size);
-        b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_double + b_block_space_size);
+        a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_odd_buf);
+        b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_odd_buf);

         __syncthreads();
@@ -488,8 +497,6 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
     (the C threadwise transfer drops the explicit AddressSpace::Vgpr / AddressSpace::Global
      arguments, keeping CGlobalMemoryDataOperation)
@@ -502,32 +509,32 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
             make_tuple(I0, I0, I0, I0),
             c_thread_buf,
             c_m0_m1_n0_n1_global_desc,
-            p_c_global,
+            c_global_buf,
             c_m0_m1_n0_n1_global_tensor_iterator_hacks);
     }

     template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
-    __device__ void Run(const AGlobalDesc& a_k_m_global_desc,
-                        const FloatAB* __restrict__ p_a_global,
-                        const BGlobalDesc& b_k_n_global_desc,
-                        const FloatAB* __restrict__ p_b_global,
-                        const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
-                        FloatC* __restrict__ p_c_global,
-                        const CBlockClusterDesc& c_block_cluster_desc,
-                        integral_constant<bool, HasMainKBlockLoop>,
-                        integral_constant<bool, HasDoubleTailKBlockLoop>) const
+    __device__ static void Run(const FloatAB* __restrict__ p_a_global,
+                               const FloatAB* __restrict__ p_b_global,
+                               FloatC* __restrict__ p_c_global,
+                               const AGlobalDesc& a_k_m_global_desc,
+                               const BGlobalDesc& b_k_n_global_desc,
+                               const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
+                               const CBlockClusterDesc& c_block_cluster_desc,
+                               integral_constant<bool, HasMainKBlockLoop>,
+                               integral_constant<bool, HasDoubleTailKBlockLoop>)
     {
         constexpr index_t shared_block_size =
             GetSharedMemoryNumberOfByte() / sizeof(FloatAB);

         __shared__ FloatAB p_shared_block[shared_block_size];

-        Run(a_k_m_global_desc,
-            p_a_global,
-            b_k_n_global_desc,
-            p_b_global,
-            c_m0_m1_n0_n1_global_desc,
-            p_c_global,
-            c_block_cluster_desc,
+        Run(p_a_global,
+            p_b_global,
+            p_c_global,
+            a_k_m_global_desc,
+            b_k_n_global_desc,
+            c_m0_m1_n0_n1_global_desc,
+            c_block_cluster_desc,
             p_shared_block,
             integral_constant<bool, HasMainKBlockLoop>{},
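For orientation, the kernel-side counterpart of the __CONSTANT__ descriptor passing is the two-step cast that the hunk above keeps as context: cast the __CONSTANT__ pointer to a plain void pointer, then reinterpret it as the concrete descriptor type and copy the descriptor by value. A sketch with names taken from the kernel above:

    // Kernel side: recover a descriptor passed as a __CONSTANT__ void pointer.
    // First cast away __CONSTANT__, then reinterpret as the concrete descriptor type.
    const auto c_block_cluster_desc =
        *reinterpret_cast<const CBlockClusterDesc*>((const void*)p_c_block_cluster_desc);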
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v2.hpp
View file @
be49a8c5
...
@@ -84,6 +84,13 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
         constexpr auto I2 = Number<2>{};
         constexpr auto I3 = Number<3>{};

+        const auto a_global_buf = make_dynamic_buffer<AddressSpace::Global>(
+            p_a_global, a_e_k_global_desc.GetElementSpaceSize());
+        const auto b_global_buf = make_dynamic_buffer<AddressSpace::Global>(
+            p_b_global, b_e_n_ho_wo_global_desc.GetElementSpaceSize());
+        auto c_global_buf = make_dynamic_buffer<AddressSpace::Global>(
+            p_c_global, c_k_n_ho_wo_global_desc.GetElementSpaceSize());
+
         constexpr auto E = EPerBlock * 3 * 3;
         // const auto E = a_e_k_global_desc.GetLength(I0);
...
@@ -192,8 +199,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
             1,
             ABlockTransferSrcScalarPerVector,
             ABlockTransferDstScalarPerVector_K,
-            AddressSpace::Global,
-            AddressSpace::Lds,
             1,
             1,
             AThreadTransferSrcResetCoordinateAfterRun,
...
@@ -216,19 +221,17 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
             BBlockTransferSrcAccessOrder,
             BBlockTransferSrcVectorDim,
             BBlockTransferSrcScalarPerVector,
-            AddressSpace::Global,
-            AddressSpace::Vgpr,
             InMemoryDataOperation::Set,
             1,
             true>(b_e_n_ho_wo_global_desc,
                   make_multi_index(0, 0, ho_thread_data_on_global, wo_thread_data_on_global));

-        FloatAB* p_a_block = p_shared_block;
-        auto a_block_buf = make_dynamic_buffer(p_a_block);
+        auto a_block_buf = make_dynamic_buffer<AddressSpace::Lds>(
+            p_shared_block, a_e_k_desc.GetElementSpaceSize());

         // register allocation for output
-        StaticBuffer<FloatAcc, c_k_n_ho_wo_thread_desc.GetElementSpaceSize()> c_thread_buf;
+        StaticBuffer<AddressSpace::Vgpr, FloatAcc, c_k_n_ho_wo_thread_desc.GetElementSpaceSize()>
+            c_thread_buf;

         // initialize output thread tensor
         ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
...
@@ -250,21 +253,21 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
             BGlobalMoveSliceWindowIteratorHacks{};

         // double regsiter buffer for b
-        StaticBuffer<FloatAB, b_e_n_ho_wo_thread_desc.GetElementSpaceSize()> b_thread_even_buf,
-            b_thread_odd_buf;
+        StaticBuffer<AddressSpace::Vgpr, FloatAB, b_e_n_ho_wo_thread_desc.GetElementSpaceSize()>
+            b_thread_even_buf, b_thread_odd_buf;

         // LDS double buffer: preload data
         {
-            a_blockwise_copy.RunRead(a_e_k_global_desc, p_a_global, a_e_k_global_iterator_hacks);
+            a_blockwise_copy.RunRead(a_e_k_global_desc, a_global_buf, a_e_k_global_iterator_hacks);

             b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
-                                      p_b_global,
+                                      b_global_buf,
                                       b_e_n_ho_wo_thread_desc,
                                       make_tuple(I0, I0, I0, I0),
                                       b_thread_even_buf,
                                       b_e_n_ho_wo_global_iterator_hacks);

-            a_blockwise_copy.RunWrite(a_e_k_desc, p_a_block);
+            a_blockwise_copy.RunWrite(a_e_k_desc, a_block_buf);
         }

         __syncthreads();
...
@@ -282,7 +285,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
                 b_thread_slice_copy_step);

             b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
-                                      p_b_global,
+                                      b_global_buf,
                                       b_e_n_ho_wo_thread_desc,
                                       make_tuple(I0, I0, I0, I0),
                                       b_thread_odd_buf,
...
@@ -298,7 +301,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
                 b_thread_slice_copy_step);

             b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
-                                      p_b_global,
+                                      b_global_buf,
                                       b_e_n_ho_wo_thread_desc,
                                       make_tuple(I0, I0, I0, I0),
                                       b_thread_even_buf,
...
@@ -321,7 +324,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
                 b_thread_slice_copy_step);

             b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
-                                      p_b_global,
+                                      b_global_buf,
                                       b_e_n_ho_wo_thread_desc,
                                       make_tuple(I0, I0, I0, I0),
                                       b_thread_odd_buf,
...
@@ -358,8 +361,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
             CThreadTransferSrcDstAccessOrder,
             CThreadTransferSrcDstVectorDim,
             CThreadTransferDstScalarPerVector,
-            AddressSpace::Vgpr,
-            AddressSpace::Global,
             CGlobalMemoryDataOperation,
             1,
             true>(
...
@@ -370,7 +371,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
                 make_tuple(I0, I0, I0, I0),
                 c_thread_buf,
                 c_k_n_ho_wo_global_desc,
-                p_c_global,
+                c_global_buf,
                 c_k_n_ho_wo_global_tensor_iterator_hacks);
         }
     }
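The pattern of this change is that raw FloatAB*/FloatAcc* pointers are wrapped in buffer objects (a dynamic buffer view for global/LDS memory, a static buffer for registers) before any transfer touches them. As a rough, host-only sketch of that idea, not the library's actual dynamic_buffer.hpp/static_buffer.hpp implementation, the classes, the size type, and the zero-on-invalid behaviour below are assumptions for illustration:

#include <array>
#include <cstring>
#include <iostream>

// Simplified stand-ins for the buffer types used in this commit; names and
// signatures are illustrative only.
enum class AddressSpace { Generic, Global, Lds, Vgpr };

// A dynamic buffer is a thin view over a raw pointer plus an element count,
// tagged with the address space it lives in.
template <AddressSpace AS, typename T>
struct DynamicBuffer
{
    T* p_data_;
    long element_space_size_;

    static constexpr AddressSpace GetAddressSpace() { return AS; }

    const T& operator[](long i) const { return p_data_[i]; }
    T& operator()(long i) { return p_data_[i]; }

    // Vector access: X may be a wider vector type; invalid offsets read back
    // as a zero-initialized value.
    template <typename X>
    X Get(long i, bool is_valid) const
    {
        X x{};
        if(is_valid)
            std::memcpy(&x, &p_data_[i], sizeof(X));
        return x;
    }

    template <typename X>
    void Set(long i, bool is_valid, const X& x)
    {
        if(is_valid)
            std::memcpy(&p_data_[i], &x, sizeof(X));
    }
};

template <AddressSpace AS, typename T>
DynamicBuffer<AS, T> make_dynamic_buffer(T* p, long element_space_size)
{
    return DynamicBuffer<AS, T>{p, element_space_size};
}

// A static buffer is register-resident storage of compile-time size.
template <AddressSpace AS, typename T, int N>
using StaticBuffer = std::array<T, N>;

int main()
{
    float p_c_global[8] = {};
    auto c_global_buf = make_dynamic_buffer<AddressSpace::Global>(p_c_global, 8);

    StaticBuffer<AddressSpace::Vgpr, float, 4> c_thread_buf{1.f, 2.f, 3.f, 4.f};

    // write one "thread tile" element out through the global view
    c_global_buf.Set<float>(3, /*is_valid=*/true, c_thread_buf[2]);
    std::cout << c_global_buf[3] << "\n"; // prints 3
}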
...
...
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops.hpp
View file @
be49a8c5
...
@@ -12,34 +12,89 @@
 namespace ck {

-template <typename GridwiseGemm,
-          typename AGlobalDesc,
-          typename FloatA,
-          typename BGlobalDesc,
-          typename FloatB,
-          typename CGlobalDesc,
-          typename FloatC,
-          typename CBlockClusterDesc,
-          bool HasMainKBlockLoop,
-          bool HasDoubleTailKBlockLoop>
-__global__ void kernel_dynamic_gemm_xdlops_v1(const AGlobalDesc a_k_m_global_desc,
-                                              const FloatA* __restrict__ p_a_global,
-                                              const BGlobalDesc b_k_n_global_desc,
-                                              const FloatB* __restrict__ p_b_global,
-                                              const CGlobalDesc c_m0_m1_n0_n1_global_desc,
-                                              FloatC* __restrict__ p_c_global,
-                                              const CBlockClusterDesc c_block_cluster_desc)
-{
-    GridwiseGemm{}.Run(a_k_m_global_desc,
-                       p_a_global,
-                       b_k_n_global_desc,
-                       p_b_global,
-                       c_m0_m1_n0_n1_global_desc,
-                       p_c_global,
-                       c_block_cluster_desc,
-                       integral_constant<bool, HasMainKBlockLoop>{},
-                       integral_constant<bool, HasDoubleTailKBlockLoop>{});
-}
+#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
+template <typename GridwiseGemm,
+          typename FloatA,
+          typename FloatB,
+          typename FloatC,
+          typename AGlobalDesc,
+          typename BGlobalDesc,
+          typename CGlobalDesc,
+          typename CBlockClusterDesc,
+          bool HasMainKBlockLoop,
+          bool HasDoubleTailKBlockLoop>
+__global__ void
+#if CK_USE_LAUNCH_BOUNDS
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+#endif
+        kernel_dynamic_gemm_xdlops_v1(const FloatA* __restrict__ p_a_global,
+                                      const FloatB* __restrict__ p_b_global,
+                                      FloatC* __restrict__ p_c_global,
+                                      const AGlobalDesc a_k_m_global_desc,
+                                      const BGlobalDesc b_k_n_global_desc,
+                                      const CGlobalDesc c_m0_m1_n0_n1_global_desc,
+                                      const CBlockClusterDesc c_block_cluster_desc)
+{
+    GridwiseGemm::Run(p_a_global,
+                      p_b_global,
+                      p_c_global,
+                      a_k_m_global_desc,
+                      b_k_n_global_desc,
+                      c_m0_m1_n0_n1_global_desc,
+                      c_block_cluster_desc,
+                      integral_constant<bool, HasMainKBlockLoop>{},
+                      integral_constant<bool, HasDoubleTailKBlockLoop>{});
+}
+#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
+// pass tensor descriptor by __CONSTANT__ void pointer
+// __CONSTANT__ is needed to inform compiler void pointers in the kernel signature are pointing to
+// non-modifiable parameter address space, so compiler can enable corresponding optimization
+template <typename GridwiseGemm,
+          typename FloatA,
+          typename FloatB,
+          typename FloatC,
+          typename AGlobalDesc,
+          typename BGlobalDesc,
+          typename CGlobalDesc,
+          typename CBlockClusterDesc,
+          bool HasMainKBlockLoop,
+          bool HasDoubleTailKBlockLoop>
+__global__ void
+#if CK_USE_LAUNCH_BOUNDS
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+#endif
+        kernel_dynamic_gemm_xdlops_v1(const FloatA* __restrict__ p_a_global,
+                                      const FloatB* __restrict__ p_b_global,
+                                      FloatC* __restrict__ p_c_global,
+                                      const void __CONSTANT__* p_a_k_m_global_desc,
+                                      const void __CONSTANT__* p_b_k_n_global_desc,
+                                      const void __CONSTANT__* p_c_m0_m1_n0_n1_global_desc,
+                                      const void __CONSTANT__* p_c_block_cluster_desc)
+{
+    // first cast void __CONSTANT__ void* to void*
+    // second cast void* to Desc*
+    // the copy constructor of tensor descriptor doesn't take address_space(4)
+    const auto a_k_m_global_desc =
+        *reinterpret_cast<const AGlobalDesc*>((const void*)p_a_k_m_global_desc);
+    const auto b_k_n_global_desc =
+        *reinterpret_cast<const BGlobalDesc*>((const void*)p_b_k_n_global_desc);
+    const auto c_m0_m1_n0_n1_global_desc =
+        *reinterpret_cast<const CGlobalDesc*>((const void*)p_c_m0_m1_n0_n1_global_desc);
+    const auto c_block_cluster_desc =
+        *reinterpret_cast<const CBlockClusterDesc*>((const void*)p_c_block_cluster_desc);
+
+    GridwiseGemm::Run(p_a_global,
+                      p_b_global,
+                      p_c_global,
+                      a_k_m_global_desc,
+                      b_k_n_global_desc,
+                      c_m0_m1_n0_n1_global_desc,
+                      c_block_cluster_desc,
+                      integral_constant<bool, HasMainKBlockLoop>{},
+                      integral_constant<bool, HasDoubleTailKBlockLoop>{});
+}
+#endif

 template <index_t BlockSize,
           typename FloatAB,
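The __CONSTANT__ variant above round-trips each tensor descriptor through an opaque pointer: the host packs the descriptor into kernel-argument space, and the kernel reinterprets the bytes back into the concrete descriptor type. A minimal host-only sketch of that round trip, with a made-up AKMGlobalDesc standing in for the real descriptor type and no GPU address-space qualifier, might look like this:

#include <cstdint>
#include <iostream>

// Illustrative stand-in; the real descriptors are generated template types.
struct AKMGlobalDesc
{
    std::int32_t lengths[2];
    std::int32_t strides[2];
};

// What the kernel body does with its `const void __CONSTANT__*` parameter,
// minus the address-space qualifier that only exists on the GPU side.
void consume_desc(const void* p_desc_blob)
{
    const auto a_k_m_global_desc = *reinterpret_cast<const AKMGlobalDesc*>(p_desc_blob);
    std::cout << "K=" << a_k_m_global_desc.lengths[0]
              << " M=" << a_k_m_global_desc.lengths[1] << "\n";
}

int main()
{
    // Host side: build the descriptor and hand it over as an opaque blob,
    // exactly as a kernel argument / constant buffer would carry it.
    AKMGlobalDesc desc{{256, 1024}, {1024, 1}};
    consume_desc(&desc);
}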
...
@@ -114,22 +169,29 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
     }

     template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
-    __device__ void Run(const AGlobalDesc& a_k_m_global_desc,
-                        const FloatAB* __restrict__ p_a_global,
-                        const BGlobalDesc& b_k_n_global_desc,
-                        const FloatAB* __restrict__ p_b_global,
-                        const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
-                        FloatC* __restrict__ p_c_global,
-                        const CBlockClusterDesc& c_block_cluster_desc,
-                        FloatAB* __restrict__ p_shared_block,
-                        integral_constant<bool, HasMainKBlockLoop>,
-                        integral_constant<bool, HasDoubleTailKBlockLoop>) const
+    __device__ static void Run(const FloatAB* __restrict__ p_a_global,
+                               const FloatAB* __restrict__ p_b_global,
+                               FloatC* __restrict__ p_c_global,
+                               const AGlobalDesc& a_k_m_global_desc,
+                               const BGlobalDesc& b_k_n_global_desc,
+                               const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
+                               const CBlockClusterDesc& c_block_cluster_desc,
+                               FloatAB* __restrict__ p_shared_block,
+                               integral_constant<bool, HasMainKBlockLoop>,
+                               integral_constant<bool, HasDoubleTailKBlockLoop>)
     {
         constexpr auto I0 = Number<0>{};
         constexpr auto I1 = Number<1>{};
         constexpr auto I2 = Number<2>{};
         constexpr auto I3 = Number<3>{};

+        const auto a_global_buf = make_dynamic_buffer<AddressSpace::Global>(
+            p_a_global, a_k_m_global_desc.GetElementSpaceSize());
+        const auto b_global_buf = make_dynamic_buffer<AddressSpace::Global>(
+            p_b_global, b_k_n_global_desc.GetElementSpaceSize());
+        auto c_global_buf = make_dynamic_buffer<AddressSpace::Global>(
+            p_c_global, c_m0_m1_n0_n1_global_desc.GetElementSpaceSize());
+
         const auto K = a_k_m_global_desc.GetLength(I0);
         const auto M = a_k_m_global_desc.GetLength(I1);
         const auto N = b_k_n_global_desc.GetLength(I1);
...
@@ -179,8 +241,6 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
             1,
             ABlockTransferSrcScalarPerVector,
             ABlockTransferDstScalarPerVector_M,
-            AddressSpace::Global,
-            AddressSpace::Lds,
             1,
             1,
             AThreadTransferSrcResetCoordinateAfterRun,
...
@@ -208,8 +268,6 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
             1,
             BBlockTransferSrcScalarPerVector,
             BBlockTransferDstScalarPerVector_N,
-            AddressSpace::Global,
-            AddressSpace::Lds,
             1,
             1,
             BThreadTransferSrcResetCoordinateAfterRun,
...
@@ -284,8 +342,8 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
         FloatAB* p_b_block_double = p_shared_block + 2 * a_block_space_size;

         // register allocation for output
-        auto c_thread_buf =
-            make_static_buffer<FloatAcc>(c_m0_m1_n0_n1_thread_desc.GetElementSpaceSize());
+        auto c_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatAcc>(
+            c_m0_m1_n0_n1_thread_desc.GetElementSpaceSize());

         ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
                                            decltype(c_m0_m1_n0_n1_thread_desc),
...
@@ -306,25 +364,23 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
         constexpr auto b_k_n_global_move_slice_window_iterator_hack =
             BGlobalMoveSliceWindowIteratorHacks{};

-        FloatAB* p_a_block_even = p_a_block_double;
-        FloatAB* p_b_block_even = p_b_block_double;
-
-        FloatAB* p_a_block_odd = p_a_block_double + a_block_space_size;
-        FloatAB* p_b_block_odd = p_b_block_double + b_block_space_size;
-
-        auto a_block_even_buf = make_dynamic_buffer(p_a_block_even);
-        auto b_block_even_buf = make_dynamic_buffer(p_b_block_even);
-        auto a_block_odd_buf  = make_dynamic_buffer(p_a_block_odd);
-        auto b_block_odd_buf  = make_dynamic_buffer(p_b_block_odd);
+        auto a_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(
+            p_a_block_double, a_k_m_block_desc.GetElementSpaceSize());
+        auto b_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(
+            p_b_block_double, b_k_n_block_desc.GetElementSpaceSize());
+
+        auto a_block_odd_buf = make_dynamic_buffer<AddressSpace::Lds>(
+            p_a_block_double + a_block_space_size, a_k_m_block_desc.GetElementSpaceSize());
+        auto b_block_odd_buf = make_dynamic_buffer<AddressSpace::Lds>(
+            p_b_block_double + b_block_space_size, b_k_n_block_desc.GetElementSpaceSize());

         // LDS double buffer: preload data into LDS
         {
-            a_blockwise_copy.RunRead(a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks);
-            b_blockwise_copy.RunRead(b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
+            a_blockwise_copy.RunRead(a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks);
+            b_blockwise_copy.RunRead(b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);

-            a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_double);
-            b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_double);
+            a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_even_buf);
+            b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_even_buf);
         }

         if constexpr(HasMainKBlockLoop)
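The preload above primes one of the two LDS tiles; the main K loop then alternates between the even and odd buffers so that global loads for the next tile overlap the GEMM on the current one, and a two-tile tail drains the pipeline. A compact host-side sketch of that ping-pong schedule (sizes and the "GEMM" are placeholders, not the real kernel):

#include <algorithm>
#include <array>
#include <iostream>
#include <vector>

// Host-side sketch of the even/odd LDS double-buffer pipeline: while the GEMM
// consumes the tile staged in one buffer, the next tile is loaded into the
// other, and the two buffers swap roles every iteration.
constexpr int K         = 16;
constexpr int KPerBlock = 4;

void fake_gemm(const std::array<int, KPerBlock>& tile, long& acc)
{
    for(int v : tile)
        acc += v; // stands in for blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf)
}

int main()
{
    std::vector<int> a_global(K);
    for(int i = 0; i < K; ++i)
        a_global[i] = i;

    std::array<int, KPerBlock> even_buf{}, odd_buf{};
    long acc = 0;

    // preload: first tile into the "even" buffer
    std::copy_n(a_global.begin(), KPerBlock, even_buf.begin());

    // main loop: two tiles per trip, alternating buffers
    int k_begin = 0;
    do
    {
        std::copy_n(a_global.begin() + k_begin + KPerBlock, KPerBlock, odd_buf.begin());
        fake_gemm(even_buf, acc); // compute on current (even) tile

        std::copy_n(a_global.begin() + k_begin + 2 * KPerBlock, KPerBlock, even_buf.begin());
        fake_gemm(odd_buf, acc); // compute on current (odd) tile

        k_begin += 2 * KPerBlock;
    } while(k_begin < K - 2 * KPerBlock);

    // double tail: load the last tile, then drain both buffers
    std::copy_n(a_global.begin() + k_begin + KPerBlock, KPerBlock, odd_buf.begin());
    fake_gemm(even_buf, acc);
    fake_gemm(odd_buf, acc);

    std::cout << acc << "\n"; // 0 + 1 + ... + 15 = 120
}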
...
@@ -347,16 +403,16 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
                 // LDS doubel buffer: load next data from device mem
                 a_blockwise_copy.RunRead(
-                    a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks);
+                    a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks);
                 b_blockwise_copy.RunRead(
-                    b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
+                    b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);

                 // LDS double buffer: GEMM on current data
                 blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);

                 // LDS double buffer: store next data to LDS
-                a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_odd);
-                b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_odd);
+                a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_odd_buf);
+                b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_odd_buf);

                 // odd iteration
                 a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc,
...
@@ -370,16 +426,16 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
                 // LDS doubel buffer: load next data from device mem
                 a_blockwise_copy.RunRead(
-                    a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks);
+                    a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks);
                 b_blockwise_copy.RunRead(
-                    b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
+                    b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);

                 // LDS double buffer: GEMM on current data
                 blockwise_gemm.Run(a_block_odd_buf, b_block_odd_buf, c_thread_buf);

                 // LDS double buffer: store next data to LDS
-                a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_even);
-                b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_even);
+                a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_even_buf);
+                b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_even_buf);

                 k_block_data_begin += 2 * KPerBlock;
             } while(k_block_data_begin < K - 2 * KPerBlock);
...
@@ -398,15 +454,15 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
             __syncthreads();

             // LDS double buffer: load last data from device mem
-            a_blockwise_copy.RunRead(a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks);
-            b_blockwise_copy.RunRead(b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
+            a_blockwise_copy.RunRead(a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks);
+            b_blockwise_copy.RunRead(b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);

             // LDS double buffer: GEMM on 2nd-last data
             blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);

             // LDS double buffer: store last data to LDS
-            a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_double + a_block_space_size);
-            b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_double + b_block_space_size);
+            a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_odd_buf);
+            b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_odd_buf);

             __syncthreads();
...
@@ -441,8 +497,6 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
             CThreadTransferSrcDstAccessOrder,
             CThreadTransferSrcDstVectorDim,
             CThreadTransferDstScalarPerVector,
-            AddressSpace::Vgpr,
-            AddressSpace::Global,
             CGlobalMemoryDataOperation,
             1,
             true>{
...
@@ -455,32 +509,32 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
                 make_tuple(I0, I0, I0, I0),
                 c_thread_buf,
                 c_m0_m1_n0_n1_global_desc,
-                p_c_global,
+                c_global_buf,
                 c_m0_m1_n0_n1_global_tensor_iterator_hacks);
         }
     }

     template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
-    __device__ void Run(const AGlobalDesc& a_k_m_global_desc,
-                        const FloatAB* __restrict__ p_a_global,
-                        const BGlobalDesc& b_k_n_global_desc,
-                        const FloatAB* __restrict__ p_b_global,
-                        const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
-                        FloatC* __restrict__ p_c_global,
-                        const CBlockClusterDesc& c_block_cluster_desc,
-                        integral_constant<bool, HasMainKBlockLoop>,
-                        integral_constant<bool, HasDoubleTailKBlockLoop>) const
+    __device__ static void Run(const FloatAB* __restrict__ p_a_global,
+                               const FloatAB* __restrict__ p_b_global,
+                               FloatC* __restrict__ p_c_global,
+                               const AGlobalDesc& a_k_m_global_desc,
+                               const BGlobalDesc& b_k_n_global_desc,
+                               const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
+                               const CBlockClusterDesc& c_block_cluster_desc,
+                               integral_constant<bool, HasMainKBlockLoop>,
+                               integral_constant<bool, HasDoubleTailKBlockLoop>)
     {
         constexpr index_t shared_block_size =
             GetSharedMemoryNumberOfByte() / sizeof(FloatAB);

         __shared__ FloatAB p_shared_block[shared_block_size];

-        Run(a_k_m_global_desc,
-            p_a_global,
-            b_k_n_global_desc,
-            p_b_global,
-            c_m0_m1_n0_n1_global_desc,
-            p_c_global,
+        Run(p_a_global,
+            p_b_global,
+            p_c_global,
+            a_k_m_global_desc,
+            b_k_n_global_desc,
+            c_m0_m1_n0_n1_global_desc,
            c_block_cluster_desc,
            p_shared_block,
            integral_constant<bool, HasMainKBlockLoop>{},
...
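The second Run overload above is the familiar shape of a static entry point that owns the scratch allocation and forwards to a worker that only sees a raw pointer. Outside of device code the same shape could be sketched like this; the __shared__ block is replaced by an ordinary allocation and the "work" is a placeholder, so this is only an illustration of the structure:

#include <cstddef>
#include <iostream>
#include <vector>

struct GemmSketch
{
    static constexpr std::size_t GetSharedMemoryNumberOfByte() { return 64 * sizeof(float); }

    // inner worker: operates on caller-provided scratch
    static void Run(const float* p_a, float* p_c, float* p_shared_block)
    {
        p_shared_block[0] = p_a[0] * 2.0f; // stage through the "LDS" scratch
        p_c[0]            = p_shared_block[0];
    }

    // outer entry: allocates the scratch, then forwards
    static void Run(const float* p_a, float* p_c)
    {
        std::vector<float> shared_block(GetSharedMemoryNumberOfByte() / sizeof(float));
        Run(p_a, p_c, shared_block.data());
    }
};

int main()
{
    float a = 3.0f, c = 0.0f;
    GemmSketch::Run(&a, &c);
    std::cout << c << "\n"; // prints 6
}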
...
composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp
View file @
be49a8c5
...
@@ -54,8 +54,6 @@ template <typename SrcData,
           typename DimAccessOrder,
           index_t DstVectorDim,
           index_t DstScalarPerVector,
-          AddressSpace SrcAddressSpace,
-          AddressSpace DstAddressSpace,
           InMemoryDataOperation DstInMemOp,
           index_t DstScalarStrideInVector,
           bool DstResetCoordinateAfterRun,
...
@@ -72,7 +70,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
     __device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v1r3(
         const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
-        : dst_slice_origin_coord_(make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_idx))
+        : dst_coord_(make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_idx))
     {
         static_assert(SrcDesc::IsKnownAtCompileTime(),
                       "wrong! SrcDesc need to known at compile-time");
...
@@ -80,15 +78,18 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
     __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
     {
-        dst_slice_origin_coord_ = make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_idx);
+        dst_coord_ = make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_idx);
     }

-    template <typename SrcSliceOriginIdx, typename SrcBuffer, typename DstIteratorHacks>
+    template <typename SrcSliceOriginIdx,
+              typename SrcBuffer,
+              typename DstBuffer,
+              typename DstIteratorHacks>
     __device__ void Run(const SrcDesc&,
                         const SrcSliceOriginIdx&,
                         const SrcBuffer& src_buf,
                         const DstDesc& dst_desc,
-                        DstData* p_dst,
+                        DstBuffer& dst_buf,
                         const DstIteratorHacks& dst_iterator_hacks)
     {
         static_assert(SrcDesc::IsKnownAtCompileTime(),
...
@@ -191,12 +192,12 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
                 return dst_data_idx;
             }();

-            // copy data
             typename vector_type_maker<DstData, DstScalarPerVector>::type dst_vector;

             using dst_vector_t =
                 typename vector_type_maker<DstData, DstScalarPerVector>::type::type;

+            // copy data from src_buf into dst_vector
             static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
                 constexpr index_t src_offset = src_desc.CalculateOffset(
                     src_slice_origin_idx + dst_data_idx + i * dst_scalar_step_in_vector);
...
@@ -205,37 +206,14 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
                     type_convert<DstData>{}(src_buf[Number<src_offset>{}]);
             });

-            const bool is_dst_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
-                dst_desc, dst_slice_origin_coord_);
-
-            if constexpr(SrcAddressSpace == AddressSpace::Vgpr &&
-                         DstAddressSpace == AddressSpace::Global)
-            {
-#if CK_USE_AMD_BUFFER_ADDRESSING
-                amd_buffer_store_v2<DstData, DstScalarPerVector>(
-                    dst_vector.template AsType<dst_vector_t>()[Number<0>{}],
-                    p_dst,
-                    dst_slice_origin_coord_.GetOffset(),
-                    is_dst_valid,
-                    dst_desc.GetElementSpaceSize());
-#else
-                if(is_dst_valid)
-                {
-                    *reinterpret_cast<dst_vector_t*>(&(p_dst[dst_slice_origin_coord_.GetOffset()])) =
-                        dst_vector.template AsType<dst_vector_t>()[Number<0>{}];
-                }
-#endif
-            }
-            else
-            {
-                if(is_dst_valid)
-                {
-                    *reinterpret_cast<dst_vector_t*>(&(p_dst[dst_slice_origin_coord_.GetOffset()])) =
-                        dst_vector.template AsType<dst_vector_t>()[Number<0>{}];
-                }
-            }
+            const bool is_dst_valid =
+                coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
+
+            // copy data from dst_vector into dst_buf
+            dst_buf.template Set<dst_vector_t>(
+                dst_coord_.GetOffset(),
+                is_dst_valid,
+                dst_vector.template AsType<dst_vector_t>()[Number<0>{}]);

             constexpr auto move_on_dim = [&]() constexpr
             {
...
@@ -259,15 +237,13 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
                 {
                     if constexpr(forward_sweep[i])
                     {
-                        move_dynamic_tensor_coordinate(dst_desc,
-                                                       dst_slice_origin_coord_,
-                                                       dst_forward_iterators[dim_access_order[i]]);
+                        move_dynamic_tensor_coordinate(
+                            dst_desc, dst_coord_, dst_forward_iterators[dim_access_order[i]]);
                     }
                     else
                     {
-                        move_dynamic_tensor_coordinate(dst_desc,
-                                                       dst_slice_origin_coord_,
-                                                       dst_backward_iterators[dim_access_order[i]]);
+                        move_dynamic_tensor_coordinate(
+                            dst_desc, dst_coord_, dst_backward_iterators[dim_access_order[i]]);
                     }
                 }
             });
...
@@ -279,11 +255,16 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
             const auto dst_reset_iterator =
                 make_dynamic_tensor_coordinate_iterator(dst_desc, GetDstCoordinateResetStep());

-            move_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_coord_, dst_reset_iterator);
+            move_dynamic_tensor_coordinate(dst_desc, dst_coord_, dst_reset_iterator);
         }
     }

-    __device__ void Run(const SrcData* p_src, const DstDesc& dst_desc, DstData* p_dst)
+    template <typename SrcSliceOriginIdx, typename SrcBuffer, typename DstBuffer>
+    __device__ void Run(const SrcDesc&,
+                        const SrcSliceOriginIdx&,
+                        const SrcBuffer& src_buf,
+                        const DstDesc& dst_desc,
+                        DstBuffer& dst_buf)
     {
         constexpr index_t ntransform_dst = DstDesc::GetNumOfTransform();
...
@@ -293,7 +274,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
             make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
                        generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));

-        Run(p_src, dst_desc, p_dst, dst_iterator_hacks);
+        Run(SrcDesc{}, SrcSliceOriginIdx{}, src_buf, dst_desc, dst_buf, dst_iterator_hacks);
     }

     __device__ static constexpr auto GetDstCoordinateResetStep()
...
@@ -371,18 +352,22 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
         const auto adjusted_step =
             make_dynamic_tensor_coordinate_iterator(dst_desc, adjusted_step_idx);

-        move_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_coord_, adjusted_step);
+        move_dynamic_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
     }

     private:
-    DstCoord dst_slice_origin_coord_;
+    DstCoord dst_coord_;
}; // namespace ck

// Assume:
-// 1. src_desc is not known at compile-time
-// 2. dst_desc is known at compile-time
-// 3. src_slice_origin_idx is not known at compile-time
-// 4. dst_slice_origin_idx is known at compile-time and it's 0
+// 1. src:
+//   1. SrcDesc is not known at compile-time
+//   2. SrcBuffer is DynamicBuffer
+//   3. src_slice_origin_idx is not known at compile-time
+// 2. dst:
+//   1. DstDesc is known at compile-time
+//   2. DstBuffer is StaticBuffer
+//   3. dst_slice_origin_idx is known at compile-time
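The contract spelled out above has the destination side fully known at compile time, so every destination offset can be folded into a constant index into a register-resident (StaticBuffer-like) array. A small host-only sketch of that "offsets known at compile time" idea, with a made-up row-major StaticDesc standing in for the real descriptor machinery:

#include <array>
#include <cstddef>
#include <iostream>

// Illustrative stand-in for a compile-time tensor descriptor.
template <std::size_t NRows, std::size_t NCols>
struct StaticDesc
{
    static constexpr std::size_t CalculateOffset(std::size_t row, std::size_t col)
    {
        return row * NCols + col; // row-major, evaluated at compile time
    }
};

int main()
{
    using Desc = StaticDesc<2, 4>;
    std::array<float, 8> thread_buf{}; // plays the role of a StaticBuffer in registers

    // dst_slice_origin is (1, 0) and known at compile time, so the offsets
    // below are constants the compiler can fold away.
    constexpr std::size_t o0 = Desc::CalculateOffset(1, 0);
    constexpr std::size_t o1 = Desc::CalculateOffset(1, 1);
    thread_buf[o0] = 10.0f;
    thread_buf[o1] = 11.0f;

    std::cout << thread_buf[4] << " " << thread_buf[5] << "\n"; // 10 11
}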
template <typename SrcData,
          typename DstData,
          typename SrcDesc,
...
@@ -391,8 +376,6 @@ template <typename SrcData,
          typename DimAccessOrder,
          index_t SrcVectorDim,
          index_t SrcScalarPerVector,
-          AddressSpace SrcAddressSpace,
-          AddressSpace DstAddressSpace,
          index_t SrcScalarStrideInVector,
          bool SrcResetCoordinateAfterRun,
          typename std::enable_if<DstDesc::IsKnownAtCompileTime(), bool>::type = false>
...
@@ -408,7 +391,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
    __device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v2(const SrcDesc& src_desc,
                                                                 const Index& src_slice_origin_idx)
-        : src_slice_origin_coord_(make_dynamic_tensor_coordinate(src_desc, src_slice_origin_idx))
+        : src_coord_(make_dynamic_tensor_coordinate(src_desc, src_slice_origin_idx))
    {
        static_assert(DstDesc::IsKnownAtCompileTime(),
                      "wrong! SrcDesc need to known at compile-time");
...
@@ -416,12 +399,15 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
    __device__ void SetDstSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
    {
-        src_slice_origin_coord_ = make_dynamic_tensor_coordinate(src_desc, src_slice_origin_idx);
+        src_coord_ = make_dynamic_tensor_coordinate(src_desc, src_slice_origin_idx);
    }

-    template <typename DstBuffer, typename DstSliceOriginIdx, typename SrcIteratorHacks>
+    template <typename SrcBuffer,
+              typename DstBuffer,
+              typename DstSliceOriginIdx,
+              typename SrcIteratorHacks>
    __device__ void Run(const SrcDesc& src_desc,
-                        const SrcData* p_src,
+                        const SrcBuffer& src_buf,
                        const DstDesc&,
                        const DstSliceOriginIdx&,
                        DstBuffer& dst_buf,
...
@@ -525,41 +511,19 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
                return src_data_idx;
            }();

-            // copy data
-            static_assert(DstAddressSpace == AddressSpace::Vgpr, "wrong! hardcode for vgpr dst");
-
            typename vector_type_maker<SrcData, SrcScalarPerVector>::type src_vector;

            using src_vector_t =
                typename vector_type_maker<SrcData, SrcScalarPerVector>::type::type;

-            const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
-                src_desc, src_slice_origin_coord_);
-
-            if constexpr(SrcAddressSpace == AddressSpace::Global)
-            {
-#if CK_USE_AMD_BUFFER_ADDRESSING
-                src_vector.template AsType<src_vector_t>()(Number<0>{}) =
-                    amd_buffer_load_v2<SrcData, SrcScalarPerVector>(
-                        p_src,
-                        src_slice_origin_coord_.GetOffset(),
-                        is_src_valid,
-                        src_desc.GetElementSpaceSize());
-#else
-                src_vector.template AsType<src_vector_t>()(Number<0>{}) =
-                    is_src_valid ? *reinterpret_cast<const src_vector_t*>(
-                                       &p_src[src_slice_origin_coord_.GetOffset()])
-                                 : src_vector_t{0};
-#endif
-            }
-            else
-            {
-                src_vector.template AsType<src_vector_t>()(Number<0>{}) =
-                    is_src_valid ? *reinterpret_cast<const src_vector_t*>(
-                                       &p_src[src_slice_origin_coord_.GetOffset()])
-                                 : src_vector_t{0};
-            }
+            const bool is_src_valid =
+                coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
+
+            // copy data from src_buf into src_vector
+            src_vector.template AsType<src_vector_t>()(Number<0>{}) =
+                src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid);

+            // copy data from src_vector into dst_buf
            static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
                constexpr index_t dst_offset =
                    dst_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) + src_data_idx +
...
@@ -590,15 +554,13 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
                {
                    if constexpr(forward_sweep[i])
                    {
-                        move_dynamic_tensor_coordinate(src_desc,
-                                                       src_slice_origin_coord_,
-                                                       src_forward_iterators[dim_access_order[i]]);
+                        move_dynamic_tensor_coordinate(
+                            src_desc, src_coord_, src_forward_iterators[dim_access_order[i]]);
                    }
                    else
                    {
-                        move_dynamic_tensor_coordinate(src_desc,
-                                                       src_slice_origin_coord_,
-                                                       src_backward_iterators[dim_access_order[i]]);
+                        move_dynamic_tensor_coordinate(
+                            src_desc, src_coord_, src_backward_iterators[dim_access_order[i]]);
                    }
                }
            });
...
@@ -610,13 +572,13 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
            const auto src_reset_iterator =
                make_dynamic_tensor_coordinate_iterator(src_desc, GetSrcCoordinateResetStep());

-            move_dynamic_tensor_coordinate(src_desc, src_slice_origin_coord_, src_reset_iterator);
+            move_dynamic_tensor_coordinate(src_desc, src_coord_, src_reset_iterator);
        }
    }

-    template <typename DstBuffer, typename DstSliceOriginIdx>
+    template <typename SrcBuffer, typename DstBuffer, typename DstSliceOriginIdx>
    __device__ void Run(const SrcDesc& src_desc,
-                        const SrcData* p_src,
+                        const SrcBuffer& src_buf,
                        const DstDesc&,
                        const DstSliceOriginIdx&,
                        DstBuffer& dst_buf)
...
@@ -629,7 +591,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
            make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
                       generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));

-        Run(src_desc, p_src, DstDesc{}, DstSliceOriginIdx{}, dst_buf, src_iterator_hacks);
+        Run(src_desc, src_buf, DstDesc{}, DstSliceOriginIdx{}, dst_buf, src_iterator_hacks);
    }

    __device__ static constexpr auto GetSrcCoordinateResetStep()
...
@@ -707,17 +669,18 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
        const auto adjusted_step =
            make_dynamic_tensor_coordinate_iterator(src_desc, adjusted_step_idx);

-        move_dynamic_tensor_coordinate(src_desc, src_slice_origin_coord_, adjusted_step);
+        move_dynamic_tensor_coordinate(src_desc, src_coord_, adjusted_step);
    }

    private:
-    SrcCoord src_slice_origin_coord_;
+    SrcCoord src_coord_;
}; // namespace ck

// Assume:
// 1. src_desc and dst_desc are not known at compile-time
-// 2. src_slice_origin and dst_slice_origin are not known at compile-time,
-// 3. Use thread buffer
+// 2. SrcBuffer and DstBuffer are DynamicBuffer
+// 3. src_slice_origin and dst_slice_origin are not known at compile-time,
+// 4. Use thread buffer
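The v3 contract described above stages data through a private thread buffer: RunRead pulls a slice from the (dynamic) source into buffer_, and RunWrite later pushes it to the (dynamic) destination, with the slice windows moved in between. A one-dimensional host-side sketch of that read/stage/write structure (the real transfer walks multi-dimensional coordinates and vectorizes; names and sizes here are illustrative):

#include <array>
#include <iostream>
#include <vector>

template <int SliceLength>
struct SliceTransferSketch
{
    std::array<float, SliceLength> buffer_{}; // thread-private staging buffer
    int src_origin_ = 0;
    int dst_origin_ = 0;

    void RunRead(const std::vector<float>& src)
    {
        for(int i = 0; i < SliceLength; ++i)
            buffer_[i] = src[src_origin_ + i];
    }

    void RunWrite(std::vector<float>& dst) const
    {
        for(int i = 0; i < SliceLength; ++i)
            dst[dst_origin_ + i] = buffer_[i];
    }

    void MoveSrcSliceWindow(int step) { src_origin_ += step; }
    void MoveDstSliceWindow(int step) { dst_origin_ += step; }
};

int main()
{
    std::vector<float> src{1, 2, 3, 4, 5, 6, 7, 8};
    std::vector<float> dst(8, 0.0f);

    SliceTransferSketch<4> transfer;
    transfer.RunRead(src);  // stage src[0..3]
    transfer.RunWrite(dst); // write to dst[0..3]

    transfer.MoveSrcSliceWindow(4);
    transfer.MoveDstSliceWindow(4);
    transfer.RunRead(src);  // stage src[4..7]
    transfer.RunWrite(dst); // write to dst[4..7]

    std::cout << dst[0] << " " << dst[7] << "\n"; // 1 8
}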
template <typename SliceLengths,
          InMemoryDataOperation DstInMemOp,
          typename SrcData,
...
@@ -732,8 +695,6 @@ template <typename SliceLengths,
          index_t DstScalarPerVector,
          index_t SrcScalarStrideInVector,
          index_t DstScalarStrideInVector,
-          AddressSpace SrcAddressSpace,
-          AddressSpace DstAddressSpace,
          bool SrcResetCoordinateAfterRun, // control whether to move back src coordinate after each
                                           // RunRead(), will be fused with MoveSrcSliceWindow to
                                           // save addr computation
...
@@ -755,16 +716,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
                                                          const Index& src_slice_origin,
                                                          const DstDesc& dst_desc,
                                                          const Index& dst_slice_origin)
-        : src_slice_origin_coord_(make_dynamic_tensor_coordinate(src_desc, src_slice_origin)),
-          dst_slice_origin_coord_(make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin))
+        : src_coord_(make_dynamic_tensor_coordinate(src_desc, src_slice_origin)),
+          dst_coord_(make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin))
    {
-        static_assert(SrcAddressSpace == AddressSpace::Global or
-                          SrcAddressSpace == AddressSpace::Lds,
-                      "wrong!");
-        static_assert(DstAddressSpace == AddressSpace::Global or
-                          DstAddressSpace == AddressSpace::Lds,
-                      "wrong!");
-
        // TODO: fix this
        static_assert(is_same<SrcData, DstData>::value,
                      "wrong! current implementation assume SrcData and DstData are same type");
...
@@ -772,19 +726,27 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
    __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
    {
-        src_slice_origin_coord_ = make_dynamic_tensor_coordinate(src_desc, src_slice_origin_idx);
+        src_coord_ = make_dynamic_tensor_coordinate(src_desc, src_slice_origin_idx);
    }

    __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
    {
-        dst_slice_origin_coord_ = make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_idx);
+        dst_coord_ = make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_idx);
    }

-    template <typename SrcIteratorHacks>
+    template <typename SrcBuffer, typename SrcIteratorHacks>
    __device__ void RunRead(const SrcDesc& src_desc,
-                            const SrcData* p_src,
+                            const SrcBuffer& src_buf,
                            const SrcIteratorHacks& src_iterator_hacks)
    {
+        static_assert(SrcBuffer::GetAddressSpace() == AddressSpace::Global or
+                          SrcBuffer::GetAddressSpace() == AddressSpace::Lds,
+                      "wrong!");
+
+        static_assert(
+            is_same<remove_cv_t<remove_reference_t<typename SrcBuffer::type>>,
+                    remove_cv_t<remove_reference_t<SrcData>>>::value,
+            "wrong! SrcBuffer and SrcData data type are inconsistent");
+
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
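With the address space and element type now carried by the buffer type itself, RunRead/RunWrite can verify them with static_assert instead of taking extra AddressSpace template parameters on the transfer. A stand-alone sketch of that compile-time check, with stand-in names for the ck types:

#include <type_traits>

enum class AddressSpace { Global, Lds, Vgpr };

// Illustrative buffer view carrying its address space and element type.
template <AddressSpace AS, typename T>
struct BufferView
{
    using type = T;
    T* p_data_;
    static constexpr AddressSpace GetAddressSpace() { return AS; }
};

template <typename SrcData, typename SrcBuffer>
void run_read_checked(const SrcBuffer& /*src_buf*/)
{
    // reject buffers that are not readable from global or LDS memory
    static_assert(SrcBuffer::GetAddressSpace() == AddressSpace::Global or
                      SrcBuffer::GetAddressSpace() == AddressSpace::Lds,
                  "wrong!");
    // reject element-type mismatches between the buffer and the transfer
    static_assert(std::is_same<typename SrcBuffer::type, SrcData>::value,
                  "wrong! SrcBuffer and SrcData data type are inconsistent");
    // ... the actual copy would go here ...
}

int main()
{
    float data[4] = {};
    BufferView<AddressSpace::Global, float> buf{data};
    run_read_checked<float>(buf); // compiles; a Vgpr buffer or a double buffer would not
}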
...
@@ -869,37 +831,16 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
                return src_data_idx;
            }();

-            // copy data from src_buf to src_tmp_vector
            vector_type_maker_t<SrcData, SrcScalarPerVector> src_tmp_vector;

            using src_vector_t = typename decltype(src_tmp_vector)::type;

-            const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
-                src_desc, src_slice_origin_coord_);
-
-            if constexpr(SrcAddressSpace == AddressSpace::Global)
-            {
-#if CK_USE_AMD_BUFFER_ADDRESSING
-                src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
-                    amd_buffer_load_v2<SrcData, SrcScalarPerVector>(
-                        p_src,
-                        src_slice_origin_coord_.GetOffset(),
-                        is_src_valid,
-                        src_desc.GetElementSpaceSize());
-#else
-                src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
-                    is_src_valid ? *reinterpret_cast<const src_vector_t*>(
-                                       &p_src[src_slice_origin_coord_.GetOffset()])
-                                 : src_vector_t{0};
-#endif
-            }
-            else
-            {
-                src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
-                    is_src_valid ? *reinterpret_cast<const src_vector_t*>(
-                                       &p_src[src_slice_origin_coord_.GetOffset()])
-                                 : src_vector_t{0};
-            }
+            const bool is_src_valid =
+                coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
+
+            // copy data from src_buf to src_tmp_vector
+            src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
+                src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid);

            // copy data from src_tmp_vector to buffer_
            static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
...
@@ -933,16 +874,12 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
                    if constexpr(forward_sweep[i])
                    {
                        move_dynamic_tensor_coordinate(
-                            src_desc,
-                            src_slice_origin_coord_,
-                            src_forward_iterators[src_dim_access_order[i]]);
+                            src_desc, src_coord_, src_forward_iterators[src_dim_access_order[i]]);
                    }
                    else
                    {
                        move_dynamic_tensor_coordinate(
-                            src_desc,
-                            src_slice_origin_coord_,
-                            src_backward_iterators[src_dim_access_order[i]]);
+                            src_desc, src_coord_, src_backward_iterators[src_dim_access_order[i]]);
                    }
                }
            });
...
@@ -954,14 +891,23 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
            const auto src_reset_iterator =
                make_dynamic_tensor_coordinate_iterator(src_desc, GetSrcCoordinateResetStep());

-            move_dynamic_tensor_coordinate(src_desc, src_slice_origin_coord_, src_reset_iterator);
+            move_dynamic_tensor_coordinate(src_desc, src_coord_, src_reset_iterator);
        }
    }

-    template <typename DstIteratorHacks>
-    __device__ void RunWrite(const DstDesc& dst_desc,
-                             DstData* p_dst,
-                             const DstIteratorHacks& dst_iterator_hacks)
+    template <typename DstBuffer, typename DstIteratorHacks>
+    __device__ void RunWrite(const DstDesc& dst_desc,
+                             DstBuffer& dst_buf,
+                             const DstIteratorHacks& dst_iterator_hacks)
    {
+        static_assert(DstBuffer::GetAddressSpace() == AddressSpace::Global or
+                          DstBuffer::GetAddressSpace() == AddressSpace::Lds,
+                      "wrong!");
+
+        static_assert(
+            is_same<remove_cv_t<remove_reference_t<typename DstBuffer::type>>,
+                    remove_cv_t<remove_reference_t<DstData>>>::value,
+            "wrong! SrcBuffer or DstBuffer data type is wrong");
+
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
...
@@ -1050,13 +996,6 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
                return dst_data_idx;
            }();

-            // copy data
-            // hardcoding for ds_write
-            // TODO refactor transfer_data() to encapsulate this
-            static_assert(DstAddressSpace == AddressSpace::Lds &&
-                              DstInMemOp == InMemoryDataOperation::Set,
-                          "wrong! hardcoded for ds_write");
-
            vector_type_maker_t<DstData, DstScalarPerVector> dst_tmp_vector;

            // copy data from buffer_ to dst_tmp_vector
...
@@ -1070,8 +1009,13 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
            using dst_vector_t = typename decltype(dst_tmp_vector)::type;

            // copy data from dst_tmp_vector to dst_buf
-            *reinterpret_cast<dst_vector_t*>(p_dst + dst_slice_origin_coord_.GetOffset()) =
-                dst_tmp_vector.template AsType<dst_vector_t>()[Number<0>{}];
+            const bool is_dst_valid =
+                coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
+
+            dst_buf.template Set<dst_vector_t>(
+                dst_coord_.GetOffset(),
+                is_dst_valid,
+                dst_tmp_vector.template AsType<dst_vector_t>()[Number<0>{}]);

            constexpr auto move_on_dim = [&]() constexpr
            {
...
@@ -1097,16 +1041,12 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
                    if constexpr(forward_sweep[i])
                    {
                        move_dynamic_tensor_coordinate(
-                            dst_desc,
-                            dst_slice_origin_coord_,
-                            dst_forward_iterators[dst_dim_access_order[i]]);
+                            dst_desc, dst_coord_, dst_forward_iterators[dst_dim_access_order[i]]);
                    }
                    else
                    {
                        move_dynamic_tensor_coordinate(
-                            dst_desc,
-                            dst_slice_origin_coord_,
-                            dst_backward_iterators[dst_dim_access_order[i]]);
+                            dst_desc, dst_coord_, dst_backward_iterators[dst_dim_access_order[i]]);
                    }
                }
            });
...
@@ -1118,11 +1058,12 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
            const auto dst_reset_iterator =
                make_dynamic_tensor_coordinate_iterator(dst_desc, GetDstCoordinateResetStep());

-            move_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_coord_, dst_reset_iterator);
+            move_dynamic_tensor_coordinate(dst_desc, dst_coord_, dst_reset_iterator);
        }
    }

-    __device__ void RunRead(const SrcDesc& src_desc, const SrcData* p_src)
+    template <typename SrcBuffer>
+    __device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf)
    {
        constexpr index_t ntransform_src = SrcDesc::GetNumOfTransform();
...
@@ -1132,10 +1073,11 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
            make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
                       generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));

-        RunRead(src_desc, p_src, src_iterator_hacks);
+        RunRead(src_desc, src_buf, src_iterator_hacks);
    }

-    __device__ void RunWrite(const DstDesc& dst_desc, DstData* p_dst)
+    template <typename DstBuffer>
+    __device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf)
    {
        constexpr index_t ntransform_dst = DstDesc::GetNumOfTransform();
...
@@ -1145,7 +1087,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
            make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
                       generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));

-        RunWrite(dst_desc, p_dst, dst_iterator_hacks);
+        RunWrite(dst_desc, dst_buf, dst_iterator_hacks);
    }

    __device__ static constexpr auto GetSrcCoordinateResetStep()
...
@@ -1285,7 +1227,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
        const auto adjusted_step =
            make_dynamic_tensor_coordinate_iterator(src_desc, adjusted_step_idx);

-        move_dynamic_tensor_coordinate(src_desc, src_slice_origin_coord_, adjusted_step);
+        move_dynamic_tensor_coordinate(src_desc, src_coord_, adjusted_step);
    }

    // src_slice_origin_step_idx need to be known at compile-time, for performance reason
...
@@ -1304,7 +1246,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
        const auto adjusted_step = make_dynamic_tensor_coordinate_iterator(
            src_desc, adjusted_step_idx, src_move_slice_window_iterator_hack);

-        move_dynamic_tensor_coordinate(src_desc, src_slice_origin_coord_, adjusted_step);
+        move_dynamic_tensor_coordinate(src_desc, src_coord_, adjusted_step);
    }

    // dst_slice_origin_step_idx need to be known at compile-time, for performance reason
    __device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
...
@@ -1319,7 +1261,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
        const auto adjusted_step =
            make_dynamic_tensor_coordinate_iterator(dst_desc, adjusted_step_idx);

-        move_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_coord_, adjusted_step);
+        move_dynamic_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
    }

    private:
...
@@ -1328,10 +1270,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
    static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize();

-    StaticBuffer<SrcData, buffer_size_> buffer_;
+    StaticBuffer<AddressSpace::Vgpr, SrcData, buffer_size_> buffer_;

-    SrcCoord src_slice_origin_coord_;
-    DstCoord dst_slice_origin_coord_;
+    SrcCoord src_coord_;
+    DstCoord dst_coord_;
};

// Assume:
...
@@ -1356,8 +1298,6 @@ template <
          typename DimAccessOrder,
          index_t SrcVectorDim,
          index_t SrcScalarPerVector,
-          AddressSpace SrcAddressSpace,
-          AddressSpace DstAddressSpace,
          index_t SrcScalarStrideInVector,
          typename std::enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                                  bool>::type = false>
...
@@ -1480,7 +1420,6 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
        move_dynamic_tensor_coordinate(
            src_desc, src_data_coord, src_ref_to_data_disp_coord_iterator);

-        // copy data from src_buf into src_tmp_buffer
        vector_type_maker_t<SrcData, SrcScalarPerVector> src_tmp_vector;

        using src_vector_t = typename decltype(src_tmp_vector)::type;
...
@@ -1488,9 +1427,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
        const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
            src_desc, src_data_coord);

+        // copy data from src_buf into src_tmp_vector
        src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
-            is_src_valid ? src_buf.template Get<src_vector_t>(src_data_coord.GetOffset())
-                         : src_vector_t{0};
+            src_buf.template Get<src_vector_t>(src_data_coord.GetOffset(), is_src_valid);

        // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
        // DstData)
...
...
composable_kernel/include/utility/amd_buffer_addressing_v2.hpp
View file @
be49a8c5
...
@@ -323,7 +323,7 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
    }
    else if constexpr(N == 2)
    {
-#if !CK_WORKAROUND_SWDEV_XXXXXX
+#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
        return __llvm_amdgcn_raw_buffer_load_i8x2(
            src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
#else
...
@@ -335,7 +335,7 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
    }
    else if constexpr(N == 4)
    {
-#if !CK_WORKAROUND_SWDEV_XXXXXX
+#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
        return __llvm_amdgcn_raw_buffer_load_i8x4(
            src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
#else
...
@@ -347,7 +347,7 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
    }
    else if constexpr(N == 8)
    {
-#if !CK_WORKAROUND_SWDEV_XXXXXX
+#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
        vector_type<int8_t, 8> tmp;

        tmp.AsType<int8x4_t>()(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_i8x4(
...
@@ -369,7 +369,7 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
    }
    else if constexpr(N == 16)
    {
-#if !CK_WORKAROUND_SWDEV_XXXXXX
+#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
        vector_type<int8_t, 16> tmp;

        tmp.AsType<int8x4_t>()(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_i8x4(
...
@@ -483,7 +483,7 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type
    }
    else if constexpr(N == 2)
    {
-#if !CK_WORKAROUND_SWDEV_XXXXXX
+#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
        __llvm_amdgcn_raw_buffer_store_i8x2(src_thread_data,
                                            dst_wave_buffer_resource,
                                            dst_thread_addr_offset,
...
@@ -499,7 +499,7 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type
    }
    else if constexpr(N == 4)
    {
-#if !CK_WORKAROUND_SWDEV_XXXXXX
+#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
        __llvm_amdgcn_raw_buffer_store_i8x4(src_thread_data,
                                            dst_wave_buffer_resource,
                                            dst_thread_addr_offset,
...
...
composable_kernel/include/utility/buffer.hpp
deleted 100644 → 0
View file @ bcdc330d
#ifndef CK_BUFFER_HPP
#define CK_BUFFER_HPP

#include "statically_indexed_array.hpp"

namespace ck {

template <typename T, index_t N>
struct StaticBuffer : public StaticallyIndexedArray<T, N>
{
    using type = T;
    using base = StaticallyIndexedArray<T, N>;

    __host__ __device__ constexpr StaticBuffer() : base{} {}

    __host__ __device__ static constexpr bool IsStaticBuffer() { return true; }

    __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
};

template <typename T, index_t N>
__host__ __device__ constexpr auto make_static_buffer(Number<N>)
{
    return StaticBuffer<T, N>{};
}

template <typename T>
struct DynamicBuffer
{
    using type = T;

    T* p_data_;

    __host__ __device__ constexpr DynamicBuffer(T* p_data) : p_data_{p_data} {}

    __host__ __device__ constexpr const T& operator[](index_t i) const { return p_data_[i]; }

    __host__ __device__ constexpr T& operator()(index_t i) { return p_data_[i]; }

    template <typename X,
              typename std::enable_if<
                  is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
                          typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
                  bool>::type = false>
    __host__ __device__ constexpr const auto Get(index_t i) const
    {
        return *reinterpret_cast<const X*>(&p_data_[i]);
    }

    template <typename X,
              typename std::enable_if<
                  is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
                          typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
                  bool>::type = false>
    __host__ __device__ void Set(index_t i, const X& x)
    {
        *reinterpret_cast<X*>(&p_data_[i]) = x;
    }

    __host__ __device__ static constexpr bool IsStaticBuffer() { return false; }

    __host__ __device__ static constexpr bool IsDynamicBuffer() { return true; }
};

template <typename T>
__host__ __device__ constexpr auto make_dynamic_buffer(T* p)
{
    return DynamicBuffer<T>{p};
}

} // namespace ck
#endif
composable_kernel/include/utility/common_header.hpp
View file @ be49a8c5
@@ -8,7 +8,6 @@
 #include "container_element_picker.hpp"
 #include "data_type.hpp"
 #include "float_type.hpp"
-#include "buffer.hpp"
 #include "functional.hpp"
 #include "functional2.hpp"
 #include "functional3.hpp"
@@ -25,6 +24,8 @@
 #include "type.hpp"
 #include "utility.hpp"
 #include "magic_division.hpp"
+#include "static_buffer.hpp"
+#include "dynamic_buffer.hpp"

 #if CK_USE_AMD_INLINE_ASM
 #include "amd_inline_asm.hpp"
composable_kernel/include/utility/config.amd.hpp.in
View file @ be49a8c5
@@ -143,8 +143,13 @@
 #endif

 // workaround for compiler crash when using buffer load/store for i8
-#ifndef CK_WORKAROUND_SWDEV_XXXXXX
-#define CK_WORKAROUND_SWDEV_XXXXXX 1
+#ifndef CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
+#define CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE 1
+#endif
+
+// workaround for compiler crash when using buffer load/store for i8
+#ifndef CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
+#define CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE 1
 #endif

 namespace ck {
@@ -154,6 +159,7 @@ enum AddressSpace
     Generic,
     Global,
     Lds,
+    Sgpr,
     Vgpr
 };
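Both workaround macros are only #ifndef-guarded, so a build (or a single translation unit) can force either path without editing the generated config header. A hypothetical override, assuming config.hpp is reached through common_header.hpp as in the include list above:

// Hypothetical, not part of the commit: take the plain i8 ds_write path
// instead of the int32-cast workaround in DynamicBuffer::Set.
#define CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE 0
#include "common_header.hpp"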
composable_kernel/include/utility/dynamic_buffer.hpp
0 → 100644
View file @ be49a8c5
#ifndef CK_DYNAMIC_BUFFER_HPP
#define CK_DYNAMIC_BUFFER_HPP

namespace ck {

#include "amd_buffer_addressing_v2.hpp"

template <AddressSpace BufferAddressSpace, typename T, typename ElementSpaceSize>
struct DynamicBuffer
{
    using type = T;

    T* p_data_;
    ElementSpaceSize element_space_size_;

    __host__ __device__ constexpr DynamicBuffer(T* p_data, ElementSpaceSize element_space_size)
        : p_data_{p_data}, element_space_size_{element_space_size}
    {
    }

    __host__ __device__ static constexpr AddressSpace GetAddressSpace()
    {
        return BufferAddressSpace;
    }

    __host__ __device__ constexpr const T& operator[](index_t i) const { return p_data_[i]; }

    __host__ __device__ constexpr T& operator()(index_t i) { return p_data_[i]; }

    template <typename X,
              typename std::enable_if<
                  is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
                          typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
                  bool>::type = false>
    __host__ __device__ constexpr const auto Get(index_t i, bool is_valid_offset) const
    {
        // X contains multiple T
        constexpr index_t scalar_per_t_vector =
            scalar_type<remove_cv_t<remove_reference_t<T>>>::vector_size;

        constexpr index_t scalar_per_x_vector =
            scalar_type<remove_cv_t<remove_reference_t<X>>>::vector_size;

        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                      "wrong! X need to be multiple T");

        constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

        if constexpr(GetAddressSpace() == AddressSpace::Global)
        {
#if CK_USE_AMD_BUFFER_ADDRESSING
            return amd_buffer_load_v2<remove_cv_t<remove_reference_t<T>>, t_per_x>(
                p_data_, i, is_valid_offset, element_space_size_);
#else
            return is_valid_offset ? *reinterpret_cast<const X*>(&p_data_[i]) : X{0};
#endif
        }
        else
        {
            return is_valid_offset ? *reinterpret_cast<const X*>(&p_data_[i]) : X{0};
        }
    }

    template <typename X,
              typename std::enable_if<
                  is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
                          typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
                  bool>::type = false>
    __host__ __device__ void Set(index_t i, bool is_valid_offset, const X& x)
    {
        // X contains multiple T
        constexpr index_t scalar_per_t_vector =
            scalar_type<remove_cv_t<remove_reference_t<T>>>::vector_size;

        constexpr index_t scalar_per_x_vector =
            scalar_type<remove_cv_t<remove_reference_t<X>>>::vector_size;

        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                      "wrong! X need to be multiple T");

        constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

        if constexpr(GetAddressSpace() == AddressSpace::Global)
        {
#if CK_USE_AMD_BUFFER_ADDRESSING
            amd_buffer_store_v2<remove_cv_t<remove_reference_t<T>>, t_per_x>(
                x, p_data_, i, is_valid_offset, element_space_size_);
#else
            if(is_valid_offset)
            {
                *reinterpret_cast<X*>(&p_data_[i]) = x;
            }
#endif
        }
        else if constexpr(GetAddressSpace() == AddressSpace::Lds)
        {
            if(is_valid_offset)
            {
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
                *reinterpret_cast<X*>(&p_data_[i]) = x;
#else
                // HACK: compiler would lower IR "store<i8, 16> address_space(3)" into inefficient
                // ISA, so I try to let compiler emit use IR "store<i32, 4>" which would be lower to
                // ds_write_b128
                // TODO: remove this after compiler fix
                if constexpr(is_same<typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type,
                                     int8_t>::value)
                {
                    static_assert(
                        (is_same<remove_cv_t<remove_reference_t<T>>, int8x4_t>::value &&
                         is_same<remove_cv_t<remove_reference_t<X>>, int8x4_t>::value) ||
                            (is_same<remove_cv_t<remove_reference_t<T>>, int8x8_t>::value &&
                             is_same<remove_cv_t<remove_reference_t<X>>, int8x8_t>::value) ||
                            (is_same<remove_cv_t<remove_reference_t<T>>, int8x16_t>::value &&
                             is_same<remove_cv_t<remove_reference_t<X>>, int8x16_t>::value),
                        "wrong! not implemented for this combination, please add implementation");

                    if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8x4_t>::value &&
                                 is_same<remove_cv_t<remove_reference_t<X>>, int8x4_t>::value)
                    {
                        // HACK: cast pointer of x is bad
                        // TODO: remove this after compiler fix
                        *reinterpret_cast<int32_t*>(&p_data_[i]) =
                            *reinterpret_cast<const int32_t*>(&x);
                    }

                    if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8x8_t>::value &&
                                 is_same<remove_cv_t<remove_reference_t<X>>, int8x8_t>::value)
                    {
                        // HACK: cast pointer of x is bad
                        // TODO: remove this after compiler fix
                        *reinterpret_cast<int32x2_t*>(&p_data_[i]) =
                            *reinterpret_cast<const int32x2_t*>(&x);
                    }

                    if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8x16_t>::value &&
                                 is_same<remove_cv_t<remove_reference_t<X>>, int8x16_t>::value)
                    {
                        // HACK: cast pointer of x is bad
                        // TODO: remove this after compiler fix
                        *reinterpret_cast<int32x4_t*>(&p_data_[i]) =
                            *reinterpret_cast<const int32x4_t*>(&x);
                    }
                }
                else
                {
                    *reinterpret_cast<X*>(&p_data_[i]) = x;
                }
#endif
            }
        }
        else
        {
            if(is_valid_offset)
            {
                *reinterpret_cast<X*>(&p_data_[i]) = x;
            }
        }
    }

    __host__ __device__ static constexpr bool IsStaticBuffer() { return false; }

    __host__ __device__ static constexpr bool IsDynamicBuffer() { return true; }
};

template <AddressSpace BufferAddressSpace = AddressSpace::Generic,
          typename T,
          typename ElementSpaceSize>
__host__ __device__ constexpr auto make_dynamic_buffer(T* p, ElementSpaceSize element_space_size)
{
    return DynamicBuffer<BufferAddressSpace, T, ElementSpaceSize>{p, element_space_size};
}

} // namespace ck
#endif
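For orientation, a usage sketch of the new interface (not taken from the commit; the function name, pointer names, and the float4_t element width are assumptions): wrap raw global pointers with make_dynamic_buffer and let the validity flag travel with each vectorized access.

#include "common_header.hpp"

using namespace ck;

// Sketch: copy one float4 per call; an invalid offset reads back as zero and the
// corresponding store is dropped, because the predicate is handled inside the buffer.
__device__ void copy_one_float4(const float* p_src, float* p_dst, index_t i, index_t n)
{
    const auto src_buf = make_dynamic_buffer<AddressSpace::Global>(p_src, n);
    auto dst_buf       = make_dynamic_buffer<AddressSpace::Global>(p_dst, n);

    const bool is_valid = (i + 4 <= n); // keep the whole vector inside the element space

    const auto v = src_buf.template Get<float4_t>(i, is_valid);
    dst_buf.template Set<float4_t>(i, is_valid, v);
}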
composable_kernel/include/utility/static_buffer.hpp
0 → 100644
View file @ be49a8c5
#ifndef CK_STATIC_BUFFER_HPP
#define CK_STATIC_BUFFER_HPP

#include "statically_indexed_array.hpp"

namespace ck {

template <AddressSpace BufferAddressSpace, typename T, index_t N>
struct StaticBuffer : public StaticallyIndexedArray<T, N>
{
    using type = T;
    using base = StaticallyIndexedArray<T, N>;

    __host__ __device__ constexpr StaticBuffer() : base{} {}

    __host__ __device__ static constexpr AddressSpace GetAddressSpace()
    {
        return BufferAddressSpace;
    }

    __host__ __device__ static constexpr bool IsStaticBuffer() { return true; }

    __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
};

template <AddressSpace BufferAddressSpace = AddressSpace::Generic, typename T, index_t N>
__host__ __device__ constexpr auto make_static_buffer(Number<N>)
{
    return StaticBuffer<BufferAddressSpace, T, N>{};
}

} // namespace ck
#endif
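A correspondingly small sketch for the static counterpart (again illustrative only, not from the commit): a per-thread, register-resident buffer addressed with compile-time Number<> indices.

#include "common_header.hpp"

using namespace ck;

// Sketch: a 4-element VGPR-resident accumulator; StaticBuffer inherits
// StaticallyIndexedArray, so every element index is a compile-time Number<>.
__device__ float sum_of_four(const float* p)
{
    auto acc = make_static_buffer<AddressSpace::Vgpr, float>(Number<4>{});

    acc(Number<0>{}) = p[0];
    acc(Number<1>{}) = p[1];
    acc(Number<2>{}) = p[2];
    acc(Number<3>{}) = p[3];

    return acc[Number<0>{}] + acc[Number<1>{}] + acc[Number<2>{}] + acc[Number<3>{}];
}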
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
View file @ be49a8c5
 #include <unistd.h>
 #include "device.hpp"
 #include "host_tensor.hpp"
-#include "driver_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
+#include "transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp"
+#include "driver_dynamic_gemm_v1.hpp"

 template <class TInWei,
           ck::index_t InWeiVectorSize,
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp
View file @ be49a8c5
 #include <unistd.h>
 #include "device.hpp"
 #include "host_tensor.hpp"
-#include "driver_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp"
+#include "transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp"
+#include "driver_dynamic_gemm_v1.hpp"

 template <class TInWei,
           ck::index_t InWeiVectorSize,
driver/include/device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp
View file @ be49a8c5
@@ -63,12 +63,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
 #if 0
     // run-time variables
-    const auto in_n_c_hi_wi_desc =
-        make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(InDesc::GetLengths()));
-    const auto wei_k_c_y_x_desc =
-        make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(WeiDesc::GetLengths()));
-    const auto out_n_k_ho_wo_desc =
-        make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(OutDesc::GetLengths()));
+    const auto in_n_c0_hi_wi_desc =
+        make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(N, C0, Hi, Wi));
+    const auto wei_k_c0_y_x_desc =
+        make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(K, C0, Y, X));
+    const auto out_n_k0_ho_wo_k1_desc =
+        make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(N, K0, Ho, Wo, K1));

     const auto conv_strides   = to_multi_index(ConvStrides{});
     const auto conv_dilations = to_multi_index(ConvDilations{});
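The renamed descriptors reflect the v5r1 layout in which the output channel dimension is split, as the new N x K0 x Ho x Wo x K1 shape shows. A sketch of how that output descriptor could be built for run-time sizes; the helper name, the header location, and the assumption that K1 evenly divides K are all illustrative and not taken from the commit:

#include "common_header.hpp"
#include "dynamic_tensor_descriptor_helper.hpp" // assumed location of the *_v2 descriptor helpers

using namespace ck;

// Illustrative helper (not from the commit): build the N x K0 x Ho x Wo x K1
// packed output descriptor, with K split as K = K0 * K1.
template <index_t K1>
auto make_out_n_k0_ho_wo_k1_desc(index_t N, index_t K, index_t Ho, index_t Wo)
{
    const index_t K0 = K / K1; // assumes K % K1 == 0

    return make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(N, K0, Ho, Wo, K1));
}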