Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
78f637e4
Unverified
Commit
78f637e4
authored
Apr 08, 2024
by
zjing14
Committed by
GitHub
Apr 08, 2024
Browse files
Merge pull request #58 from ROCm/navi4x_conv_fwd
Navi4x Conv and MHA enablement
parents
7e147c64
5cb59d36
Changes
15
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
66 additions
and
78 deletions
+66
-78
example/20_grouped_conv_bwd_weight/CMakeLists.txt
example/20_grouped_conv_bwd_weight/CMakeLists.txt
+1
-1
example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt
example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt
+1
-1
example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt
example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt
+1
-1
example/38_grouped_conv_bwd_data_multiple_d/CMakeLists.txt
example/38_grouped_conv_bwd_data_multiple_d/CMakeLists.txt
+1
-1
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
...ude/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
+34
-48
include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp
...l/device_batched_contraction_multiple_d_wmma_cshuffle.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp
...evice_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp
+4
-3
include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp
.../ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp
+2
-2
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp
...device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp
...ice/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp
...impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_grouped_query_attention_forward_wmma.hpp
...vice/impl/device_grouped_query_attention_forward_wmma.hpp
+3
-2
include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp
...device/impl/device_multi_query_attention_forward_wmma.hpp
+4
-2
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp
...ation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp
+5
-7
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp
+6
-6
No files found.
example/20_grouped_conv_bwd_weight/CMakeLists.txt
View file @
78f637e4
list
(
APPEND gpu_list_xdl gfx908 gfx90a gfx940 gfx941 gfx942 gfx950
)
list
(
APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102
)
list
(
APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102
gfx1200
)
set
(
target 0
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
if
(
gpu IN_LIST gpu_list_xdl AND target EQUAL 0
)
...
...
example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt
View file @
78f637e4
list
(
APPEND gpu_list1 gfx908 gfx90a gfx940 gfx941 gfx942 gfx950
)
list
(
APPEND gpu_list2 gfx1100 gfx1101 gfx1102 gfx1103
)
list
(
APPEND gpu_list2 gfx1100 gfx1101 gfx1102 gfx1103
gfx1200
)
set
(
target 0
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
...
...
example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt
View file @
78f637e4
if
(
GPU_TARGETS MATCHES
"gfx11"
)
if
(
GPU_TARGETS MATCHES
"gfx11"
OR GPU_TARGETS MATCHES
"gfx12"
)
add_example_executable
(
example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16.cpp
)
add_example_executable
(
example_batched_gemm_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp
)
add_example_executable
(
example_self_attention_forward_wmma_fp16 self_attention_forward_wmma_fp16.cpp
)
...
...
example/38_grouped_conv_bwd_data_multiple_d/CMakeLists.txt
View file @
78f637e4
list
(
APPEND gpu_list_xdl gfx908 gfx90a gfx940 gfx941 gfx942 gfx950
)
list
(
APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102
)
list
(
APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102
gfx1200
)
set
(
target 0
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
if
(
gpu IN_LIST gpu_list_xdl AND target EQUAL 0
)
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
View file @
78f637e4
...
...
@@ -70,6 +70,9 @@ struct BlockwiseGemmWMMA
static
constexpr
index_t
A_KRow
=
2
;
static
constexpr
index_t
B_KRow
=
2
;
static
constexpr
index_t
A_KRow_
=
AEnableLds
?
1
:
2
;
static
constexpr
index_t
B_KRow_
=
BEnableLds
?
1
:
2
;
static
constexpr
index_t
A_K1
=
ABlockDesc
{}.
GetLength
(
I5
);
static
constexpr
index_t
B_K1
=
BBlockDesc
{}.
GetLength
(
I5
);
...
...
@@ -191,9 +194,6 @@ struct BlockwiseGemmWMMA
static_assert
(
MPerBlock
%
(
MPerWMMA
*
MRepeat
)
==
0
&&
NPerBlock
%
(
NPerWMMA
*
NRepeat
)
==
0
,
"wrong!"
);
static_assert
(
AEnableLds
==
true
,
"only support EnableLds"
);
static_assert
(
BEnableLds
==
true
,
"only support EnableLds"
);
}
// transposed WMMA output C' = B' * A'
...
...
@@ -316,7 +316,7 @@ struct BlockwiseGemmWMMA
// read A
a_thread_copy_
.
Run
(
a_block_desc_k0_m0_m1_m2_k1
,
make_tuple
(
Number
<
k
*
KPack
/
A_K1
>
{},
m0
,
I0
,
I0
,
I0
,
I0
),
make_tuple
(
Number
<
k
*
KPack
/
A_K1
/
A_KRow_
>
{},
m0
,
I0
,
I0
,
I0
,
I0
),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
I0
,
m0
,
I0
,
I0
,
I0
,
I0
),
...
...
@@ -326,7 +326,8 @@ struct BlockwiseGemmWMMA
// read B
b_thread_copy_
.
Run
(
b_block_desc_k0_n0_n1_n2_k1
,
make_tuple
(
Number
<
k
*
KPack
/
B_K1
>
{},
n0
,
I0
,
I0
,
I0
,
I0
),
make_tuple
(
Number
<
k
*
KPack
/
B_K1
/
B_KRow_
>
{},
n0
,
I0
,
I0
,
I0
,
I0
),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
I0
,
n0
,
I0
,
I0
,
I0
,
I0
),
...
...
@@ -372,7 +373,7 @@ struct BlockwiseGemmWMMA
// read B
b_thread_copy_
.
Run
(
b_block_desc_k0_n0_n1_n2_k1
,
make_tuple
(
Number
<
k
*
KPack
/
B_K1
>
{},
n0
,
I0
,
I0
,
I0
,
I0
),
make_tuple
(
Number
<
k
*
KPack
/
B_K1
/
B_KRow_
>
{},
n0
,
I0
,
I0
,
I0
,
I0
),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
I0
,
n0
,
I0
,
I0
,
I0
,
I0
),
...
...
@@ -380,7 +381,7 @@ struct BlockwiseGemmWMMA
// read A
a_thread_copy_
.
Run
(
a_block_desc_k0_m0_m1_m2_k1
,
make_tuple
(
Number
<
k
*
KPack
/
A_K1
>
{},
m0
,
I0
,
I0
,
I0
,
I0
),
make_tuple
(
Number
<
k
*
KPack
/
A_K1
/
A_KRow_
>
{},
m0
,
I0
,
I0
,
I0
,
I0
),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
I0
,
m0
,
I0
,
I0
,
I0
,
I0
),
...
...
@@ -442,44 +443,30 @@ struct BlockwiseGemmWMMA
static
constexpr
auto
c_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
wmma_gemm
.
GetRegSizePerWmma
()));
template
<
bool
EnableLds
>
struct
AThreadCopySelector
;
template
<
>
struct
AThreadCopySelector
<
true
>
{
using
type
=
ThreadwiseTensorSliceTransfer_v4
<
FloatA
,
FloatA
,
decltype
(
a_block_desc_k0_m0_m1_m2_k1
),
decltype
(
a_thread_desc_
),
Sequence
<
KPack
/
A_K1
/
A_KRow
,
1
,
1
,
1
,
1
,
A_K1
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
A_K1
,
A_K1
>
;
};
template
<
bool
EnableLds
>
struct
BThreadCopySelector
;
template
<
>
struct
BThreadCopySelector
<
true
>
{
using
type
=
ThreadwiseTensorSliceTransfer_v4
<
FloatB
,
FloatB
,
decltype
(
b_block_desc_k0_n0_n1_n2_k1
),
decltype
(
b_thread_desc_
),
Sequence
<
KPack
/
B_K1
/
B_KRow
,
1
,
1
,
1
,
1
,
B_K1
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
B_K1
,
B_K1
>
;
};
typename
AThreadCopySelector
<
AEnableLds
>::
type
a_thread_copy_
;
typename
BThreadCopySelector
<
BEnableLds
>::
type
b_thread_copy_
;
using
AThreadCopyType
=
ThreadwiseTensorSliceTransfer_v4
<
FloatA
,
FloatA
,
decltype
(
a_block_desc_k0_m0_m1_m2_k1
),
decltype
(
a_thread_desc_
),
Sequence
<
KPack
/
A_K1
/
A_KRow
,
1
,
1
,
1
,
1
,
A_K1
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
A_K1
,
A_K1
>
;
using
BThreadCopyType
=
ThreadwiseTensorSliceTransfer_v4
<
FloatB
,
FloatB
,
decltype
(
b_block_desc_k0_n0_n1_n2_k1
),
decltype
(
b_thread_desc_
),
Sequence
<
KPack
/
B_K1
/
B_KRow
,
1
,
1
,
1
,
1
,
B_K1
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
B_K1
,
B_K1
>
;
AThreadCopyType
a_thread_copy_
;
BThreadCopyType
b_thread_copy_
;
};
#else
template
<
index_t
BlockSize
,
...
...
@@ -537,9 +524,8 @@ struct BlockwiseGemmWMMA
// permutation
static
constexpr
index_t
A_KRow
=
AEnableLds
?
1
:
2
;
static
constexpr
index_t
B_KRow
=
BEnableLds
?
1
:
2
;
static
constexpr
index_t
A_K1
=
ABlockDesc
{}.
GetLength
(
I5
);
static
constexpr
index_t
B_K1
=
BBlockDesc
{}.
GetLength
(
I5
);
static
constexpr
index_t
A_K1
=
ABlockDesc
{}.
GetLength
(
I5
);
static
constexpr
index_t
B_K1
=
BBlockDesc
{}.
GetLength
(
I5
);
static
constexpr
auto
wmma_gemm
=
WmmaGemm
<
FloatA
,
FloatB
,
FloatAcc
,
MPerWMMA
,
NPerWMMA
,
KPack
,
TransposeC
>
{};
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp
View file @
78f637e4
...
...
@@ -829,7 +829,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
if
(
ck
::
is_navi3_supported
())
if
(
ck
::
is_navi3_supported
()
||
ck
::
is_navi4_supported
()
)
{
if
constexpr
(
!
(
is_same_v
<
AccDataType
,
float
>
||
is_same_v
<
AccDataType
,
int32_t
>
))
{
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp
View file @
78f637e4
...
...
@@ -56,7 +56,7 @@ __global__ void
bool
input_permute
,
bool
output_permute
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)
|| defined(__gfx12__)
)
// clang-format off
// ***************************************************
...
...
@@ -159,6 +159,7 @@ __global__ void
ignore
=
O
;
ignore
=
G0
;
ignore
=
G1
;
ignore
=
alpha
;
ignore
=
input_permute
;
ignore
=
output_permute
;
#endif // end of if (defined(__gfx11__))
...
...
@@ -187,7 +188,7 @@ __global__ void
index_t
head_size
,
float
alpha
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)
|| defined(__gfx12__)
)
// clang-format off
// ***************************************************
...
...
@@ -858,7 +859,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
static
bool
IsSupportedArgument
(
const
RawArg
&
arg
)
{
if
(
ck
::
is_navi3_supported
())
if
(
ck
::
is_navi3_supported
()
||
ck
::
is_navi4_supported
()
)
{
if
constexpr
(
!
(
is_same_v
<
Acc0DataType
,
float
>
||
is_same_v
<
Acc0DataType
,
int32_t
>
))
{
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp
View file @
78f637e4
...
...
@@ -94,8 +94,8 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
(
MWaves
==
1
&&
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
BLayout
>::
value
)
?
false
:
true
;
// If true, LDS is used unconditionally
static
constexpr
auto
AEnableLds_manu
=
tru
e
;
static
constexpr
auto
BEnableLds_manu
=
tru
e
;
static
constexpr
auto
AEnableLds_manu
=
fals
e
;
static
constexpr
auto
BEnableLds_manu
=
fals
e
;
static
constexpr
auto
AEnableLds
=
AEnableLds_auto
||
AEnableLds_manu
||
(
NumPrefetch
>
1
);
static
constexpr
auto
BEnableLds
=
BEnableLds_auto
||
BEnableLds_manu
||
(
NumPrefetch
>
1
);
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp
View file @
78f637e4
...
...
@@ -629,7 +629,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
// check device
if
(
ck
::
is_navi3_supported
())
if
(
ck
::
is_navi3_supported
()
||
ck
::
is_navi4_supported
()
)
{
if
constexpr
(
!
(
is_same_v
<
AccDataType
,
float
>
||
is_same_v
<
AccDataType
,
int32_t
>
))
{
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp
View file @
78f637e4
...
...
@@ -702,7 +702,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
// check device
if
(
ck
::
is_navi3_supported
())
if
(
ck
::
is_navi3_supported
()
||
ck
::
is_navi4_supported
()
)
{
if
constexpr
(
!
(
is_same_v
<
AccDataType
,
float
>
||
is_same_v
<
AccDataType
,
int32_t
>
))
{
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp
View file @
78f637e4
...
...
@@ -581,7 +581,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
namespace
ctc
=
tensor_layout
::
convolution
;
// check device
if
(
ck
::
is_navi3_supported
())
if
(
ck
::
is_navi3_supported
()
||
ck
::
is_navi4_supported
()
)
{
if
constexpr
(
!
(
is_same_v
<
AccDataType
,
float
>
||
is_same_v
<
AccDataType
,
int32_t
>
))
{
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_query_attention_forward_wmma.hpp
View file @
78f637e4
...
...
@@ -61,7 +61,7 @@ __global__ void
bool
input_permute
,
bool
output_permute
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)
|| defined(__gfx12__)
)
// clang-format off
// ***************************************************
...
...
@@ -166,6 +166,7 @@ __global__ void
ignore
=
O
;
ignore
=
G0
;
ignore
=
G1
;
ignore
=
alpha
;
ignore
=
input_permute
;
ignore
=
output_permute
;
#endif // end of if (defined(__gfx11__))
...
...
@@ -596,7 +597,7 @@ struct DeviceGroupedQueryAttentionForward_Wmma
static
bool
IsSupportedArgument
(
const
RawArg
&
arg
)
{
if
(
ck
::
is_navi3_supported
())
if
(
ck
::
is_navi3_supported
()
||
ck
::
is_navi4_supported
()
)
{
if
constexpr
(
!
(
is_same_v
<
Acc0DataType
,
float
>
||
is_same_v
<
Acc0DataType
,
int32_t
>
))
{
...
...
include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp
View file @
78f637e4
...
...
@@ -60,7 +60,8 @@ __global__ void
bool
input_permute
,
bool
output_permute
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx11__) || \
defined(__gfx12__))
// clang-format off
// ***************************************************
...
...
@@ -165,6 +166,7 @@ __global__ void
ignore
=
O
;
ignore
=
G0
;
ignore
=
G1
;
ignore
=
alpha
;
ignore
=
input_permute
;
ignore
=
output_permute
;
#endif // end of if (defined(__gfx11__))
...
...
@@ -594,7 +596,7 @@ struct DeviceMultiQueryAttentionForward_Wmma
static
bool
IsSupportedArgument
(
const
RawArg
&
arg
)
{
if
(
ck
::
is_navi3_supported
())
if
(
ck
::
is_navi3_supported
()
||
ck
::
is_navi4_supported
()
)
{
if
constexpr
(
!
(
is_same_v
<
Acc0DataType
,
float
>
||
is_same_v
<
Acc0DataType
,
int32_t
>
))
{
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp
View file @
78f637e4
...
...
@@ -571,15 +571,12 @@ struct GridwiseGemmMultipleD_Wmma
// *Caution Here repeat is shuffle repeat
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat
()
{
constexpr
index_t
MWave
=
MPerBlock
/
(
MRepeat
*
MPerWmma
);
constexpr
index_t
NWave
=
NPerBlock
/
(
NRepeat
*
NPerWmma
);
constexpr
auto
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
Number
<
CShuffleMRepeatPerShuffle
*
MWave
*
MPerWmma
>
{},
Number
<
CShuffleMRepeatPerShuffle
*
MWave
s
*
MPerWmma
>
{},
I1
,
Number
<
CShuffleNRepeatPerShuffle
*
NWave
*
NPerWmma
>
{}));
Number
<
CShuffleNRepeatPerShuffle
*
NWave
s
*
NPerWmma
>
{}));
return
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat
;
}
...
...
@@ -799,8 +796,9 @@ struct GridwiseGemmMultipleD_Wmma
const
auto
M
=
e_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
e_grid_desc_m_n
.
GetLength
(
I1
);
const
auto
MBlock
=
M
/
MPerBlock
;
const
auto
NBlock
=
N
/
NPerBlock
;
const
auto
MBlock
=
M
/
MPerBlock
;
const
auto
NBlock
=
N
/
NPerBlock
;
const
auto
e_grid_desc_mblock_mperblock_nblock_nperblock
=
transform_tensor_descriptor
(
e_grid_desc_m_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
MBlock
,
Number
<
MPerBlock
>
{})),
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp
View file @
78f637e4
...
...
@@ -522,12 +522,6 @@ struct GridwiseGemm_Wmma
c_grid_desc_m_n
);
}
using
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
CGridDesc_M_N
{}))
>
;
using
DefaultBlock2CTileMap
=
remove_cvref_t
<
decltype
(
MakeDefaultBlock2CTileMap
(
CGridDesc_M_N
{},
1
,
1
))
>
;
struct
SharedMemTrait
{
// LDS allocation for A and B: be careful of alignment
...
...
@@ -559,6 +553,12 @@ struct GridwiseGemm_Wmma
b_block_space_size_aligned
*
sizeof
(
BDataType
));
};
using
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
CGridDesc_M_N
{}))
>
;
using
DefaultBlock2CTileMap
=
remove_cvref_t
<
decltype
(
MakeDefaultBlock2CTileMap
(
CGridDesc_M_N
{},
1
,
1
))
>
;
template
<
bool
HasMainKBlockLoop
,
typename
Block2CTileMap
=
DefaultBlock2CTileMap
>
__device__
static
void
Run
(
const
ADataType
*
__restrict__
p_a_grid
,
const
BDataType
*
__restrict__
p_b_grid
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment