gaoqiong / composable_kernel

Commit b4d598bd, authored Jun 02, 2021 by Chao Liu

refactor

Parent: 06ba0a90
Showing 3 changed files with 115 additions and 80 deletions.
composable_kernel/include/driver/driver_dynamic_gemm_v1r2.hpp  +72 -44
composable_kernel/include/tensor_operation/blockwise_gemm_v2r2.hpp  +5 -4
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v1r2.hpp  +38 -32
composable_kernel/include/driver/driver_dynamic_gemm_v1r2.hpp

@@ -118,19 +118,27 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
     if(!GridwiseGemm::CheckValidity(a_k_m_grid_desc, b_k_n_grid_desc, c_m_n_grid_desc))
     {
         throw std::runtime_error(
             "wrong! GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2 has invalid setting");
     }

+    const auto a_k_m0_m1_grid_desc = GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc);
+    const auto b_k_n0_n1_grid_desc = GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc);
+
+    using AKM0M1GridDesc = decltype(a_k_m0_m1_grid_desc);
+    using BKN0N1GridDesc = decltype(b_k_n0_n1_grid_desc);
+
     // c_m0_m10_m11_n0_n10_n11_grid_desc
     const auto c_m0_m10_m11_n0_n10_n11_grid_desc =
         GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc);

     using CM0M10M11N0N10N11GridDesc = decltype(c_m0_m10_m11_n0_n10_n11_grid_desc);

-    // c_block_cluster_adaptor
-    const auto c_block_cluster_adaptor = GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
+    // c_blockid_to_m0_n0_block_cluster_adaptor
+    const auto c_blockid_to_m0_n0_block_cluster_adaptor =
+        GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc);

-    using CBlockClusterAdaptor = decltype(c_block_cluster_adaptor);
+    using CBlockIdToM0N0BlockClusterAdaptor = decltype(c_blockid_to_m0_n0_block_cluster_adaptor);

     const index_t grid_size = GridwiseGemm::CalculateGridSize(M, N);

@@ -142,15 +150,18 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
     if(has_main_k_block_loop && has_double_tail_k_block_loop)
     {
-        const auto kernel = kernel_dynamic_gemm_v1r2<GridwiseGemm,
-                                                     FloatAB,
-                                                     FloatC,
-                                                     remove_reference_t<AKMGridDesc>,
-                                                     remove_reference_t<BKNGridDesc>,
-                                                     remove_reference_t<CM0M10M11N0N10N11GridDesc>,
-                                                     remove_reference_t<CBlockClusterAdaptor>,
-                                                     true,
-                                                     true>;
+        const auto kernel =
+            kernel_dynamic_gemm_v1r2<GridwiseGemm,
+                                     FloatAB,
+                                     FloatC,
+                                     remove_reference_t<AKMGridDesc>,
+                                     remove_reference_t<BKNGridDesc>,
+                                     remove_reference_t<AKM0M1GridDesc>,
+                                     remove_reference_t<BKN0N1GridDesc>,
+                                     remove_reference_t<CM0M10M11N0N10N11GridDesc>,
+                                     remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
+                                     true,
+                                     true>;

         ave_time = launch_and_time_kernel(kernel,
                                           nrepeat,
@@ -163,20 +174,25 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
                                           p_c_grid,
                                           a_k_m_grid_desc,
                                           b_k_n_grid_desc,
+                                          a_k_m0_m1_grid_desc,
+                                          b_k_n0_n1_grid_desc,
                                           c_m0_m10_m11_n0_n10_n11_grid_desc,
-                                          c_block_cluster_adaptor);
+                                          c_blockid_to_m0_n0_block_cluster_adaptor);
     }
     else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
     {
-        const auto kernel = kernel_dynamic_gemm_v1r2<GridwiseGemm,
-                                                     FloatAB,
-                                                     FloatC,
-                                                     remove_reference_t<AKMGridDesc>,
-                                                     remove_reference_t<BKNGridDesc>,
-                                                     remove_reference_t<CM0M10M11N0N10N11GridDesc>,
-                                                     remove_reference_t<CBlockClusterAdaptor>,
-                                                     true,
-                                                     false>;
+        const auto kernel =
+            kernel_dynamic_gemm_v1r2<GridwiseGemm,
+                                     FloatAB,
+                                     FloatC,
+                                     remove_reference_t<AKMGridDesc>,
+                                     remove_reference_t<BKNGridDesc>,
+                                     remove_reference_t<AKM0M1GridDesc>,
+                                     remove_reference_t<BKN0N1GridDesc>,
+                                     remove_reference_t<CM0M10M11N0N10N11GridDesc>,
+                                     remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
+                                     true,
+                                     false>;

         ave_time = launch_and_time_kernel(kernel,
                                           nrepeat,
@@ -189,20 +205,25 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
                                           p_c_grid,
                                           a_k_m_grid_desc,
                                           b_k_n_grid_desc,
+                                          a_k_m0_m1_grid_desc,
+                                          b_k_n0_n1_grid_desc,
                                           c_m0_m10_m11_n0_n10_n11_grid_desc,
-                                          c_block_cluster_adaptor);
+                                          c_blockid_to_m0_n0_block_cluster_adaptor);
     }
     else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
     {
-        const auto kernel = kernel_dynamic_gemm_v1r2<GridwiseGemm,
-                                                     FloatAB,
-                                                     FloatC,
-                                                     remove_reference_t<AKMGridDesc>,
-                                                     remove_reference_t<BKNGridDesc>,
-                                                     remove_reference_t<CM0M10M11N0N10N11GridDesc>,
-                                                     remove_reference_t<CBlockClusterAdaptor>,
-                                                     false,
-                                                     true>;
+        const auto kernel =
+            kernel_dynamic_gemm_v1r2<GridwiseGemm,
+                                     FloatAB,
+                                     FloatC,
+                                     remove_reference_t<AKMGridDesc>,
+                                     remove_reference_t<BKNGridDesc>,
+                                     remove_reference_t<AKM0M1GridDesc>,
+                                     remove_reference_t<BKN0N1GridDesc>,
+                                     remove_reference_t<CM0M10M11N0N10N11GridDesc>,
+                                     remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
+                                     false,
+                                     true>;

         ave_time = launch_and_time_kernel(kernel,
                                           nrepeat,
@@ -215,20 +236,25 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
                                           p_c_grid,
                                           a_k_m_grid_desc,
                                           b_k_n_grid_desc,
+                                          a_k_m0_m1_grid_desc,
+                                          b_k_n0_n1_grid_desc,
                                           c_m0_m10_m11_n0_n10_n11_grid_desc,
-                                          c_block_cluster_adaptor);
+                                          c_blockid_to_m0_n0_block_cluster_adaptor);
     }
     else
     {
-        const auto kernel = kernel_dynamic_gemm_v1r2<GridwiseGemm,
-                                                     FloatAB,
-                                                     FloatC,
-                                                     remove_reference_t<AKMGridDesc>,
-                                                     remove_reference_t<BKNGridDesc>,
-                                                     remove_reference_t<CM0M10M11N0N10N11GridDesc>,
-                                                     remove_reference_t<CBlockClusterAdaptor>,
-                                                     false,
-                                                     false>;
+        const auto kernel =
+            kernel_dynamic_gemm_v1r2<GridwiseGemm,
+                                     FloatAB,
+                                     FloatC,
+                                     remove_reference_t<AKMGridDesc>,
+                                     remove_reference_t<BKNGridDesc>,
+                                     remove_reference_t<AKM0M1GridDesc>,
+                                     remove_reference_t<BKN0N1GridDesc>,
+                                     remove_reference_t<CM0M10M11N0N10N11GridDesc>,
+                                     remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
+                                     false,
+                                     false>;

         ave_time = launch_and_time_kernel(kernel,
                                           nrepeat,
@@ -241,8 +267,10 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
                                           p_c_grid,
                                           a_k_m_grid_desc,
                                           b_k_n_grid_desc,
+                                          a_k_m0_m1_grid_desc,
+                                          b_k_n0_n1_grid_desc,
                                           c_m0_m10_m11_n0_n10_n11_grid_desc,
-                                          c_block_cluster_adaptor);
+                                          c_blockid_to_m0_n0_block_cluster_adaptor);
     }

     return ave_time;
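The four if / else if branches above differ only in the two trailing boolean template arguments: two runtime flags select one of four kernel instantiations whose tail-loop behavior is fixed at compile time. Below is a minimal, self-contained sketch of that dispatch pattern; it is illustrative only (run_gemm_sketch and dispatch_sketch are made-up names, not code from this commit).

#include <cstdio>

// Tail-loop behavior is a template parameter, so each (main, double_tail)
// combination is a distinct specialization with no runtime branching inside.
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
void run_gemm_sketch()
{
    std::printf("instantiated with main=%d, double_tail=%d\n",
                HasMainKBlockLoop,
                HasDoubleTailKBlockLoop);
}

// Runtime flags pick one of the four instantiations, mirroring the driver above.
void dispatch_sketch(bool has_main_k_block_loop, bool has_double_tail_k_block_loop)
{
    if(has_main_k_block_loop && has_double_tail_k_block_loop)
        run_gemm_sketch<true, true>();
    else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
        run_gemm_sketch<true, false>();
    else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
        run_gemm_sketch<false, true>();
    else
        run_gemm_sketch<false, false>();
}

The payoff of this design is that the K-loop structure is known to the compiler inside each kernel, at the cost of compiling four kernel bodies.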
composable_kernel/include/tensor_operation/blockwise_gemm_v2r2.hpp

@@ -140,7 +140,8 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v2r2_pipeline_2x2
     public:
     __device__ BlockwiseGemm_km_kn_m0m1n0n1_v2r2_pipeline_2x2()
-        : c_thread_origin_data_idx_{CalculateCM0M1N0N1ThreadOriginIndex(get_thread_local_1d_id())},
+        : c_thread_origin_data_idx_{CalculateCM0M1N0N1ThreadOriginOnBlock(get_thread_local_1d_id())},
          a_thread_copy_{
              make_tuple(0, c_thread_origin_data_idx_[I0], c_thread_origin_data_idx_[I1])},
          b_thread_copy_{
@@ -161,14 +162,14 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v2r2_pipeline_2x2
         static_assert(M0 == 2 && N0 == 2, "wrong");
     }

-    __device__ static CIndex CalculateCM0M1N0N1ThreadOriginIndex(index_t thread_id)
+    __device__ static CIndex CalculateCM0M1N0N1ThreadOriginOnBlock(index_t thread_id)
     {
-        // upper: [M0, M100, M101, M11, N0, N100, N101, N11]
         // lower: [M0, M1, N0, N1]
+        // upper: [M0, M100, M101, M11, N0, N100, N101, N11]
         constexpr auto adaptor0 = MakeCM0M100M101M11N0N100N101N11ToM0M1N0N1BlockAdaptor();

-        // upper: [Tid, M0, M11, N0, N11]
         // lower: [M0, M100, M101, M11, N0, N100, N101, N11]
+        // upper: [Tid, M0, M11, N0, N11]
         constexpr auto adaptor1 = make_single_stage_tensor_adaptor(
             make_tuple(make_merge_transform(make_tuple(M100, N100, M101, N101)),
                        make_pass_through_transform(M0),
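The renamed CalculateCM0M1N0N1ThreadOriginOnBlock chains tensor adaptors to turn a flat thread id into that thread's origin inside the block tile. Conceptually this is a mixed-radix decode of the thread id. The sketch below hand-rolls that decode under assumed extents M100, N100, M101, N101 (illustrative values, not the ones this blockwise GEMM derives from its cluster layout), with the merge order taken from make_merge_transform(make_tuple(M100, N100, M101, N101)) above.

#include <array>

// Hypothetical per-block thread layout: Tid merges (m100, n100, m101, n101)
// with m100 varying slowest, matching the merge-transform order above.
constexpr int M100 = 2, N100 = 2, M101 = 4, N101 = 4;

// Decode a flat thread id into its (m, n) origin on the block tile; the real
// code performs the same job through adaptor0 and adaptor1.
std::array<int, 2> thread_origin_on_block_sketch(int thread_id)
{
    const int m100 = thread_id / (N100 * M101 * N101);
    const int n100 = thread_id / (M101 * N101) % N100;
    const int m101 = thread_id / N101 % M101;
    const int n101 = thread_id % N101;
    return {m100 * M101 + m101, n100 * N101 + n101};
}

The new name also states where the index lives: it is an origin on the block tile, not a global index, which is why the constructor feeds it straight into a_thread_copy_ and b_thread_copy_.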
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v1r2.hpp

@@ -17,21 +17,26 @@ template <typename GridwiseGemm,
           typename FloatC,
           typename AKMGridDesc,
           typename BKNGridDesc,
+          typename AKM0M1GridDesc,
+          typename BKN0N1GridDesc,
           typename CM0M10M11N0N10N11GridDesc,
-          typename CBlockClusterAdaptor,
+          typename CBlockIdToM0N0BlockClusterAdaptor,
           bool HasMainKBlockLoop,
           bool HasDoubleTailKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
    kernel_dynamic_gemm_v1r2(const FloatAB* __restrict__ p_a_grid,
                             const FloatAB* __restrict__ p_b_grid,
                             FloatC* __restrict__ p_c_grid,
                             const AKMGridDesc a_k_m_grid_desc,
                             const BKNGridDesc b_k_n_grid_desc,
+                             const AKM0M1GridDesc a_k_m0_m1_grid_desc,
+                             const BKN0N1GridDesc b_k_n0_n1_grid_desc,
                             const CM0M10M11N0N10N11GridDesc c_m0_m10_m11_n0_n10_n11_grid_desc,
-                             const CBlockClusterAdaptor c_block_cluster_desc)
+                             const CBlockIdToM0N0BlockClusterAdaptor
+                                 c_blockid_to_m0_n0_block_cluster_adaptor)
{
    constexpr index_t shared_block_size =
        GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
@@ -44,8 +49,10 @@ __global__ void
                      p_shared_block,
                      a_k_m_grid_desc,
                      b_k_n_grid_desc,
+                      a_k_m0_m1_grid_desc,
+                      b_k_n0_n1_grid_desc,
                      c_m0_m10_m11_n0_n10_n11_grid_desc,
-                      c_block_cluster_desc,
+                      c_blockid_to_m0_n0_block_cluster_adaptor,
                      integral_constant<bool, HasMainKBlockLoop>{},
                      integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
@@ -227,7 +234,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
     }

     __host__ __device__ static constexpr auto
-    MakeCBlockClusterAdaptor(const CMNGridDesc& c_m_n_grid_desc)
+    MakeCBlockIdToM0N0BlockClusterAdaptor(const CMNGridDesc& c_m_n_grid_desc)
     {
         const auto M = c_m_n_grid_desc.GetLength(I0);
         const auto N = c_m_n_grid_desc.GetLength(I1);
@@ -238,27 +245,32 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
         const auto M0 = M / M1;
         const auto N0 = N / N1;

-        const auto c_block_cluster_adaptor = make_cluster_descriptor_v2(make_tuple(M0, N0));
+        const auto c_blockid_to_m0_n0_block_cluster_adaptor =
+            make_cluster_descriptor_v2(make_tuple(M0, N0));

-        return c_block_cluster_adaptor;
+        return c_blockid_to_m0_n0_block_cluster_adaptor;
     }

     using AKM0M1GridDesc = decltype(MakeAKM0M1GridDescriptor(AKMGridDesc{}));
     using BKN0N1GridDesc = decltype(MakeBKN0N1GridDescriptor(BKNGridDesc{}));
     using CM0M10M11N0N10N11GridDesc = decltype(MakeCM0M10M11N0N10N11GridDescriptor(CMNGridDesc{}));
-    using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}));
+    using CBlockIdToM0N0BlockClusterAdaptor =
+        decltype(MakeCBlockIdToM0N0BlockClusterAdaptor(CMNGridDesc{}));

     template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
     __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
                                const FloatAB* __restrict__ p_b_grid,
                                FloatC* __restrict__ p_c_grid,
                                FloatAB* __restrict__ p_shared_block,
                                const AKMGridDesc& a_k_m_grid_desc,
                                const BKNGridDesc& b_k_n_grid_desc,
+                               const AKM0M1GridDesc& a_k_m0_m1_grid_desc,
+                               const BKN0N1GridDesc& b_k_n0_n1_grid_desc,
                                const CM0M10M11N0N10N11GridDesc& c_m0_m10_m11_n0_n10_n11_grid_desc,
-                               const CBlockClusterAdaptor& c_block_cluster_desc,
+                               const CBlockIdToM0N0BlockClusterAdaptor&
+                                   c_blockid_to_m0_n0_block_cluster_adaptor,
                                integral_constant<bool, HasMainKBlockLoop>,
                                integral_constant<bool, HasDoubleTailKBlockLoop>)
     {
         const auto a_global_buf = make_dynamic_buffer<AddressSpace::Global>(
             p_a_grid, a_k_m_grid_desc.GetElementSpaceSize());
@@ -271,15 +283,9 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
         const auto M = a_k_m_grid_desc.GetLength(I1);
         const auto N = b_k_n_grid_desc.GetLength(I1);

-        constexpr auto M1 = Number<MPerBlock>{};
-        constexpr auto N1 = Number<NPerBlock>{};
-
-        const auto M0 = M / M1;
-        const auto N0 = N / N1;
-
         // divide block work by [M, N]
         const auto block_work_idx =
-            c_block_cluster_desc.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
+            c_blockid_to_m0_n0_block_cluster_adaptor.CalculateBottomIndex(
+                make_multi_index(get_block_1d_id()));

         // HACK: this force index data into SGPR
         const index_t m_block_work_idx = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
@@ -568,7 +574,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
             Number<c_m10_n10_m11_n11_thread_tensor_lengths[I3]>{}));

         const auto c_m10_m11_n10_n11_thread_origin_idx_on_block =
-            blockwise_gemm.CalculateCM0M1N0N1ThreadOriginIndex(get_thread_local_1d_id());
+            blockwise_gemm.CalculateCM0M1N0N1ThreadOriginOnBlock(get_thread_local_1d_id());

         ThreadwiseDynamicTensorSliceTransfer_v1r3<FloatAcc,
                                                   FloatAcc,
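The rename from MakeCBlockClusterAdaptor to MakeCBlockIdToM0N0BlockClusterAdaptor spells out the direction of the mapping: a 1-D block id is decoded into the (m0, n0) coordinate of the output tile that block computes, where M0 = M / MPerBlock and N0 = N / NPerBlock. A hedged sketch of that decode, assuming the cluster descriptor built over make_tuple(M0, N0) behaves like a row-major decode (the library can be configured with other orders); BlockIdToM0N0Sketch is a made-up name:

#include <cstdio>

// Sketch of the block-id -> (m0, n0) decode performed by CalculateBottomIndex.
struct BlockIdToM0N0Sketch
{
    int M0, N0; // the grid is M0 x N0 output tiles of MPerBlock x NPerBlock each

    void CalculateBottomIndex(int block_id, int& m0, int& n0) const
    {
        m0 = block_id / N0; // row-major assumption
        n0 = block_id % N0;
    }
};

int main()
{
    // E.g. a 1024x512 C matrix with 128x128 tiles gives an 8x4 cluster.
    BlockIdToM0N0Sketch adaptor{8, 4};
    for(int bid : {0, 5, 31})
    {
        int m0, n0;
        adaptor.CalculateBottomIndex(bid, m0, n0);
        std::printf("block %d -> tile (%d, %d)\n", bid, m0, n0);
    }
}

Moving this decode behind a precomputed adaptor argument (rather than recomputing M0 and N0 inside Run, as the deleted lines in hunk @@ -271,15 +283,9 did) keeps the per-block index math on the host-constructed descriptor path.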