Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
04c6a978
"include/ck/utility/reduction_functions_accumulate.hpp" did not exist on "6fe3627a9eb35f1237266f1b6cc8fd3456aed67d"
Commit
04c6a978
authored
Mar 06, 2023
by
aska-0096
Browse files
Skip B-Lds real gemm
parent
f00dab9f
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
301 additions
and
123 deletions
+301
-123
example/01_gemm/gemm_wmma_fp16.cpp
example/01_gemm/gemm_wmma_fp16.cpp
+3
-3
include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp
.../ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp
+47
-23
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp
+95
-1
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp
+156
-96
No files found.
example/01_gemm/gemm_wmma_fp16.cpp
View file @
04c6a978
...
@@ -42,8 +42,8 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
...
@@ -42,8 +42,8 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
8
,
// K1
8
,
// K1
16
,
// MPerWmma
16
,
// MPerWmma
16
,
// NPerWmma
16
,
// NPerWmma
2
,
// M Repeat
8
,
// M Repeat
4
,
// N-Repeat
1
,
// N-Repeat
S
<
4
,
64
,
1
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
...
@@ -60,7 +60,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
...
@@ -60,7 +60,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
true
,
true
,
1
,
// C shuffle (M Repeat) Per store
1
,
// C shuffle (M Repeat) Per store
1
,
// C shuffle (N Repeat) Per store
1
,
// C shuffle (N Repeat) Per store
S
<
1
,
6
4
,
1
,
4
>
,
S
<
1
,
1
6
,
1
,
16
>
,
8
>
;
8
>
;
// clang-format on
// clang-format on
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp
View file @
04c6a978
...
@@ -106,12 +106,13 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
...
@@ -106,12 +106,13 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
return
matrix_padder
.
PadADescriptor_M_K
(
a_grid_desc_mraw_kraw
);
return
matrix_padder
.
PadADescriptor_M_K
(
a_grid_desc_mraw_kraw
);
}
}
#ifdef ENABLE_COLMAJOR
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
ALayout
>::
value
)
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
ALayout
>::
value
)
{
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
K
),
make_tuple
(
I1
,
StrideA
));
const
auto
a_grid_desc_mraw_kraw
=
make_naive_tensor_descriptor
(
make_tuple
(
MRaw
,
KRaw
),
make_tuple
(
I1
,
StrideA
));
return
matrix_padder
.
PadADescriptor_M_K
(
a_grid_desc_mraw_kraw
);
}
}
#endif
}();
}();
const
auto
M
=
a_grid_desc_m_k
.
GetLength
(
I0
);
const
auto
M
=
a_grid_desc_m_k
.
GetLength
(
I0
);
...
@@ -146,26 +147,33 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
...
@@ -146,26 +147,33 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
}
}
}
}
static
auto
MakeBGridDescriptor
_K0_N_K1
(
index_t
KRaw
,
index_t
NRaw
,
index_t
StrideB
)
static
auto
MakeBGridDescriptor
(
index_t
KRaw
,
index_t
NRaw
,
index_t
StrideB
)
{
{
const
auto
b_grid_desc_n
raw_kraw
=
[
&
]()
{
const
auto
b_grid_desc_n
_k
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
)
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
)
{
{
return
make_naive_tensor_descriptor
(
make_tuple
(
NRaw
,
KRaw
),
const
auto
b_grid_desc_nraw_kraw
=
make_naive_tensor_descriptor
(
make_tuple
(
NRaw
,
KRaw
),
make_tuple
(
I1
,
StrideB
));
make_tuple
(
I1
,
StrideB
));
return
matrix_padder
.
PadBDescriptor_N_K
(
b_grid_desc_nraw_kraw
);
}
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
BLayout
>::
value
)
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
BLayout
>::
value
)
{
{
return
make_naive_tensor_descriptor
(
make_tuple
(
NRaw
,
KRaw
),
const
auto
b_grid_desc_nraw_kraw
=
make_naive_tensor_descriptor
(
make_tuple
(
NRaw
,
KRaw
),
make_tuple
(
StrideB
,
I1
));
make_tuple
(
StrideB
,
I1
));
return
matrix_padder
.
PadBDescriptor_N_K
(
b_grid_desc_nraw_kraw
);
}
}
}();
}();
const
auto
b_grid_desc_n_k
=
matrix_padder
.
PadBDescriptor_N_K
(
b_grid_desc_nraw_kraw
);
const
auto
N
=
b_grid_desc_n_k
.
GetLength
(
I0
);
const
auto
N
=
b_grid_desc_n_k
.
GetLength
(
I0
);
const
auto
K
=
b_grid_desc_n_k
.
GetLength
(
I1
);
const
auto
K
=
b_grid_desc_n_k
.
GetLength
(
I1
);
assert
(
K
%
K1
==
0
);
assert
(
K
%
K1
==
0
);
if
constexpr
(
BEnableLds
)
{
const
index_t
K0
=
K
/
K1
;
const
index_t
K0
=
K
/
K1
;
return
transform_tensor_descriptor
(
return
transform_tensor_descriptor
(
...
@@ -175,6 +183,22 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
...
@@ -175,6 +183,22 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
}
else
{
constexpr
auto
B_KRow
=
WmmaK
/
K1
;
const
auto
B_KWmma
=
K
/
WmmaK
;
const
auto
N0
=
N
/
NPerBlock
;
return
transform_tensor_descriptor
(
b_grid_desc_n_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
B_KWmma
,
Number
<
B_KRow
>
{},
K1Number
)),
make_unmerge_transform
(
make_tuple
(
N0
*
NRepeat
,
Number
<
NWaves
>
{},
Number
<
NPerWmma
>
{}))),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
3
,
5
>
{},
Sequence
<
1
,
2
,
4
>
{}));
}
}
static
auto
MakeCGridDescriptor_M_N
(
index_t
MRaw
,
index_t
NRaw
,
index_t
StrideC
)
static
auto
MakeCGridDescriptor_M_N
(
index_t
MRaw
,
index_t
NRaw
,
index_t
StrideC
)
{
{
...
@@ -196,7 +220,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
...
@@ -196,7 +220,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
// Gridwise descriptor, mapping to whole given provblem.
// Gridwise descriptor, mapping to whole given provblem.
using
AGridDesc
=
decltype
(
MakeAGridDescriptor
(
1
,
1
,
1
));
using
AGridDesc
=
decltype
(
MakeAGridDescriptor
(
1
,
1
,
1
));
using
BGridDesc
_K0_N_K1
=
decltype
(
MakeBGridDescriptor
_K0_N_K1
(
1
,
1
,
1
));
using
BGridDesc
=
decltype
(
MakeBGridDescriptor
(
1
,
1
,
1
));
using
CGridDesc_M_N
=
decltype
(
MakeCGridDescriptor_M_N
(
1
,
1
,
1
));
using
CGridDesc_M_N
=
decltype
(
MakeCGridDescriptor_M_N
(
1
,
1
,
1
));
// GridwiseGemm
// GridwiseGemm
...
@@ -209,7 +233,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
...
@@ -209,7 +233,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
CDataType
,
CDataType
,
InMemoryDataOperationEnum
::
Set
,
InMemoryDataOperationEnum
::
Set
,
AGridDesc
,
AGridDesc
,
BGridDesc
_K0_N_K1
,
BGridDesc
,
CGridDesc_M_N
,
CGridDesc_M_N
,
AElementwiseOperation
,
AElementwiseOperation
,
BElementwiseOperation
,
BElementwiseOperation
,
...
@@ -281,7 +305,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
...
@@ -281,7 +305,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
{
{
a_grid_desc_
=
DeviceGemmWmma_CShuffle
::
MakeAGridDescriptor
(
M
,
K
,
StrideA
);
a_grid_desc_
=
DeviceGemmWmma_CShuffle
::
MakeAGridDescriptor
(
M
,
K
,
StrideA
);
b_grid_desc_k0_n_k1_
=
b_grid_desc_k0_n_k1_
=
DeviceGemmWmma_CShuffle
::
MakeBGridDescriptor
_K0_N_K1
(
K
,
N
,
StrideB
);
DeviceGemmWmma_CShuffle
::
MakeBGridDescriptor
(
K
,
N
,
StrideB
);
c_grid_desc_m_n_
=
DeviceGemmWmma_CShuffle
::
MakeCGridDescriptor_M_N
(
M
,
N
,
StrideC
);
c_grid_desc_m_n_
=
DeviceGemmWmma_CShuffle
::
MakeCGridDescriptor_M_N
(
M
,
N
,
StrideC
);
block_2_ctile_map_
=
block_2_ctile_map_
=
...
@@ -301,7 +325,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
...
@@ -301,7 +325,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
const
BDataType
*
p_b_grid_
;
const
BDataType
*
p_b_grid_
;
CDataType
*
p_c_grid_
;
CDataType
*
p_c_grid_
;
AGridDesc
a_grid_desc_
;
AGridDesc
a_grid_desc_
;
BGridDesc
_K0_N_K1
b_grid_desc_k0_n_k1_
;
BGridDesc
b_grid_desc_k0_n_k1_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock
;
c_grid_desc_mblock_mperblock_nblock_nperblock
;
...
@@ -371,7 +395,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
...
@@ -371,7 +395,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
BDataType
,
BDataType
,
CDataType
,
CDataType
,
remove_reference_t
<
DeviceGemmWmma_CShuffle
::
AGridDesc
>
,
remove_reference_t
<
DeviceGemmWmma_CShuffle
::
AGridDesc
>
,
remove_reference_t
<
DeviceGemmWmma_CShuffle
::
BGridDesc
_K0_N_K1
>
,
remove_reference_t
<
DeviceGemmWmma_CShuffle
::
BGridDesc
>
,
remove_reference_t
<
remove_reference_t
<
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
>
,
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
>
,
AElementwiseOperation
,
AElementwiseOperation
,
...
@@ -404,7 +428,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
...
@@ -404,7 +428,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
BDataType
,
BDataType
,
CDataType
,
CDataType
,
remove_reference_t
<
DeviceGemmWmma_CShuffle
::
AGridDesc
>
,
remove_reference_t
<
DeviceGemmWmma_CShuffle
::
AGridDesc
>
,
remove_reference_t
<
DeviceGemmWmma_CShuffle
::
BGridDesc
_K0_N_K1
>
,
remove_reference_t
<
DeviceGemmWmma_CShuffle
::
BGridDesc
>
,
remove_reference_t
<
remove_reference_t
<
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
>
,
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
>
,
AElementwiseOperation
,
AElementwiseOperation
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp
View file @
04c6a978
...
@@ -309,9 +309,9 @@ struct GridwiseGemmPipeline_v1<1, false, true>
...
@@ -309,9 +309,9 @@ struct GridwiseGemmPipeline_v1<1, false, true>
auto
a_block_buf_switch
=
a_block_buf
;
auto
a_block_buf_switch
=
a_block_buf
;
// preload data into LDS
// preload data into LDS
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
a_blockwise_copy
.
Run
(
a_blockwise_copy
.
Run
(
a_grid_desc
,
a_grid_buf
,
a_block_desc
,
a_block_origin_idx
,
a_block_buf
);
a_grid_desc
,
a_grid_buf
,
a_block_desc
,
a_block_origin_idx
,
a_block_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
...
@@ -364,6 +364,100 @@ struct GridwiseGemmPipeline_v1<1, false, true>
...
@@ -364,6 +364,100 @@ struct GridwiseGemmPipeline_v1<1, false, true>
template
<
>
template
<
>
struct
GridwiseGemmPipeline_v1
<
1
,
true
,
false
>
struct
GridwiseGemmPipeline_v1
<
1
,
true
,
false
>
{
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
__host__
__device__
static
constexpr
bool
IsSupported
(
index_t
/* num_loop */
)
{
return
true
;
}
__host__
__device__
static
constexpr
bool
CalculateHasMainLoop
(
index_t
num_loop
)
{
return
num_loop
>
1
;
}
template
<
bool
HasMainLoop
,
typename
AGridDesc
,
typename
ABlockDesc
,
typename
ABlockTransfer
,
typename
AGridBuffer
,
typename
ABlockBuffer
,
typename
ABlockTransferStep
,
typename
BGridDesc
,
typename
BBlockDesc
,
typename
BBlockTransfer
,
typename
BGridBuffer
,
typename
BBlockBuffer
,
typename
BBlockTransferStep
,
typename
BlockwiseGemm
,
typename
CThreadBuffer
>
__device__
static
void
Run
(
const
AGridDesc
&
a_grid_desc
,
const
ABlockDesc
&
a_block_desc
,
ABlockTransfer
&
a_blockwise_copy
,
const
AGridBuffer
&
a_grid_buf
,
ABlockBuffer
&
a_block_buf
,
const
ABlockTransferStep
&
a_block_copy_step
,
const
BGridDesc
&
b_grid_desc
,
const
BBlockDesc
&
b_block_desc
,
BBlockTransfer
&
b_blockwise_copy
,
const
BGridBuffer
&
b_grid_buf
,
BBlockBuffer
&
b_block_buf
,
const
BBlockTransferStep
&
b_block_copy_step
,
const
BlockwiseGemm
&
blockwise_gemm
,
CThreadBuffer
&
c_thread_buf
,
index_t
num_loop
)
{
constexpr
auto
b_block_origin_idx
=
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
);
auto
b_block_buf_switch
=
b_block_buf
;
// preload data into LDS
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
b_blockwise_copy
.
Run
(
b_grid_desc
,
b_grid_buf
,
b_block_desc
,
b_block_origin_idx
,
b_block_buf
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
// Initialize C
c_thread_buf
.
Clear
();
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
);
// main body
if
constexpr
(
HasMainLoop
)
{
index_t
i
=
0
;
do
{
b_blockwise_copy
.
Run
(
b_grid_desc
,
b_grid_buf
,
b_block_desc
,
b_block_origin_idx
,
b_block_buf_switch
);
block_sync_lds
();
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
blockwise_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
block_sync_lds
();
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
);
b_block_buf
=
b_block_buf_switch
;
++
i
;
}
while
(
i
<
(
num_loop
-
1
));
}
// tail
{
block_sync_lds
();
blockwise_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
block_sync_lds
();
}
}
};
};
template
<
>
template
<
>
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp
View file @
04c6a978
...
@@ -18,11 +18,11 @@
...
@@ -18,11 +18,11 @@
namespace
ck
{
namespace
ck
{
template
<
typename
GridwiseGemm
,
template
<
typename
GridwiseGemm
,
typename
FloatA
,
typename
ADataType
,
typename
FloatB
,
typename
BDataType
,
typename
FloatC
,
typename
CDataType
,
typename
AGridDesc
,
typename
AGridDesc
,
typename
BGridDesc
_K0_N_K1
,
typename
BGridDesc
,
typename
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
AElementwiseOperation
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
BElementwiseOperation
,
...
@@ -33,11 +33,11 @@ __global__ void
...
@@ -33,11 +33,11 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
#endif
kernel_gemm_wmma
(
const
FloatA
*
__restrict__
p_a_grid
,
kernel_gemm_wmma
(
const
ADataType
*
__restrict__
p_a_grid
,
const
FloatB
*
__restrict__
p_b_grid
,
const
BDataType
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
CDataType
*
__restrict__
p_c_grid
,
const
AGridDesc
a_grid_desc
,
const
AGridDesc
a_grid_desc
,
const
BGridDesc
_K0_N_K1
b_grid_desc
_k0_n_k1
,
const
BGridDesc
b_grid_desc
,
const
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
const
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
const
AElementwiseOperation
a_element_op
,
const
AElementwiseOperation
a_element_op
,
...
@@ -53,7 +53,7 @@ __global__ void
...
@@ -53,7 +53,7 @@ __global__ void
p_c_grid
,
p_c_grid
,
p_shared
,
p_shared
,
a_grid_desc
,
a_grid_desc
,
b_grid_desc
_k0_n_k1
,
b_grid_desc
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
a_element_op
,
a_element_op
,
b_element_op
,
b_element_op
,
...
@@ -64,7 +64,7 @@ __global__ void
...
@@ -64,7 +64,7 @@ __global__ void
ignore
=
p_b_grid
;
ignore
=
p_b_grid
;
ignore
=
p_c_grid
;
ignore
=
p_c_grid
;
ignore
=
a_grid_desc
;
ignore
=
a_grid_desc
;
ignore
=
b_grid_desc
_k0_n_k1
;
ignore
=
b_grid_desc
;
ignore
=
c_grid_desc_mblock_mperblock_nblock_nperblock
;
ignore
=
c_grid_desc_mblock_mperblock_nblock_nperblock
;
ignore
=
a_element_op
;
ignore
=
a_element_op
;
ignore
=
b_element_op
;
ignore
=
b_element_op
;
...
@@ -74,14 +74,14 @@ __global__ void
...
@@ -74,14 +74,14 @@ __global__ void
}
}
template
<
index_t
BlockSize
,
template
<
index_t
BlockSize
,
typename
FloatA
,
typename
ADataType
,
typename
FloatB
,
typename
BDataType
,
typename
FloatAcc
,
typename
AccDataType
,
typename
Float
CShuffle
,
typename
CShuffle
DataType
,
typename
FloatC
,
typename
CDataType
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
typename
AGridDesc
,
typename
AGridDesc
,
typename
BGridDesc
_K0_N_K1
,
typename
BGridDesc
,
typename
CGridDesc_M_N
,
typename
CGridDesc_M_N
,
typename
AElementwiseOperation
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
BElementwiseOperation
,
...
@@ -181,6 +181,40 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
...
@@ -181,6 +181,40 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
return
a_block_desc
;
return
a_block_desc
;
}
}
__host__
__device__
static
constexpr
auto
MakeBBlockDescriptor
()
{
constexpr
auto
b_block_desc
=
[
&
]()
{
if
constexpr
(
BEnableLds
)
{
// K0->N->K1 Per Block
constexpr
auto
K0PerBlock
=
KPerBlock
/
K1
;
constexpr
auto
max_lds_align
=
K1
;
if
constexpr
(
BBlockLdsExtraN
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
make_tuple
(
Number
<
NPerBlock
+
1
>
{}
*
K1
,
K1
,
I1
));
}
else
{
return
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
max_lds_align
);
}
}
else
{
constexpr
auto
KWmmaPerblock
=
KPerBlock
/
WmmaK
;
// KWmma->NRepeat->NWave->NRow->NPerWmma->K1 Per Thread
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
KWmmaPerblock
>
{},
Number
<
NRepeat
>
{},
I1
,
I1
,
I1
,
K1
),
make_tuple
(
Number
<
NRepeat
>
{}
*
K1
,
K1
,
K1
,
K1
,
K1
,
I1
));
}
}();
return
b_block_desc
;
}
__host__
__device__
static
constexpr
auto
MakeABlockSliceCopyStep
()
__host__
__device__
static
constexpr
auto
MakeABlockSliceCopyStep
()
{
{
constexpr
auto
a_block_copy_step
=
[
&
]()
{
constexpr
auto
a_block_copy_step
=
[
&
]()
{
...
@@ -292,43 +326,56 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
...
@@ -292,43 +326,56 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
return
a_wave_desc
;
return
a_wave_desc
;
}
}
template
<
typename
BBlockDesc_
BK0_N_BK1
>
template
<
typename
BBlockDesc_
>
__host__
__device__
static
constexpr
auto
__host__
__device__
static
constexpr
auto
MakeB
Block
Descriptor
_K0_N0_N1_N2_K1
(
const
BBlockDesc_
BK0_N_BK1
&
)
MakeB
Wave
Descriptor
(
const
BBlockDesc_
&
)
{
{
constexpr
auto
B_K0
=
BBlockDesc_BK0_N_BK1
{}.
GetLength
(
I0
);
constexpr
auto
b_wave_desc
=
[
&
]()
{
constexpr
auto
B_K1
=
BBlockDesc_BK0_N_BK1
{}.
GetLength
(
I2
);
if
constexpr
(
BEnableLds
)
{
// BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1
constexpr
auto
B_K0
=
BBlockDesc_
{}.
GetLength
(
I0
);
constexpr
auto
B_K1
=
BBlockDesc_
{}.
GetLength
(
I2
);
return
transform_tensor_descriptor
(
return
transform_tensor_descriptor
(
BBlockDesc_
BK0_N_BK1
{},
BBlockDesc_
{},
make_tuple
(
make_pass_through_transform
(
Number
<
B_K0
>
{}),
make_tuple
(
make_pass_through_transform
(
Number
<
B_K0
>
{}),
make_unmerge_transform
(
make_unmerge_transform
(
make_tuple
(
make_tuple
(
Number
<
NRepeat
>
{},
Number
<
NWaves
>
{},
Number
<
NPerWmma
>
{})),
Number
<
NRepeat
>
{},
Number
<
NWaves
>
{},
Number
<
NPerWmma
>
{})),
make_pass_through_transform
(
Number
<
B_K1
>
{})),
make_pass_through_transform
(
Number
<
B_K1
>
{})),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
,
3
>
{},
Sequence
<
4
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
,
3
>
{},
Sequence
<
4
>
{}));
}
}
__host__
__device__
static
constexpr
auto
GetBBlockDescriptor_K0PerBlock_NPerBlock_K1
()
{
constexpr
auto
max_lds_align
=
K1
;
constexpr
auto
K0PerBlock
=
KPerBlock
/
K1
;
// B matrix in LDS memory, dst of blockwise copy
constexpr
auto
b_block_desc_k0perblock_nperblock_k1
=
[
&
]()
{
if
constexpr
(
BBlockLdsExtraN
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
make_tuple
(
Number
<
NPerBlock
+
1
>
{}
*
K1
,
K1
,
I1
));
}
else
else
{
{
return
make_naive_tensor_descriptor_aligned
(
// KWmma_NRepeat_NWave_KRow_NPerWmma_K1 -> K0_NRepeat_Nwaves_NPerWmma_K1
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
max_lds_align
);
constexpr
auto
KWmma
=
BBlockDesc_
{}.
GetLength
(
I0
);
constexpr
auto
B_K1
=
BBlockDesc_
{}.
GetLength
(
I5
);
// Workaround, Freeze transform
return
transform_tensor_descriptor
(
BBlockDesc_
{},
make_tuple
(
make_freeze_transform
(
I0
),
make_pass_through_transform
(
Number
<
KWmma
>
{}),
make_pass_through_transform
(
Number
<
NRepeat
>
{}),
make_pass_through_transform
(
I1
),
make_pass_through_transform
(
I1
),
make_pass_through_transform
(
Number
<
B_K1
>
{})),
make_tuple
(
Sequence
<
3
>
{},
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
4
>
{},
Sequence
<
5
>
{}),
make_tuple
(
Sequence
<>
{},
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}));
}
}
}();
}();
return
b_
block_desc_k0perblock_nperblock_k1
;
return
b_
wave_desc
;
}
}
__host__
__device__
static
constexpr
auto
__host__
__device__
static
constexpr
auto
...
@@ -349,7 +396,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
...
@@ -349,7 +396,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
template
<
typename
Block2CTileMap
>
template
<
typename
Block2CTileMap
>
__host__
__device__
static
constexpr
bool
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
AGridDesc
&
a_grid_desc
,
CheckValidity
(
const
AGridDesc
&
a_grid_desc
,
const
BGridDesc
_K0_N_K1
&
b_grid_desc
_k0_n_k1
,
const
BGridDesc
&
b_grid_desc
,
const
CGridDesc_M_N
&
c_grid_desc_m_n
,
const
CGridDesc_M_N
&
c_grid_desc_m_n
,
const
Block2CTileMap
&
block_2_ctile_map
)
const
Block2CTileMap
&
block_2_ctile_map
)
{
{
...
@@ -378,17 +425,17 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
...
@@ -378,17 +425,17 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
const
auto
GetBProblemsizeNK
=
[
&
]()
{
const
auto
GetBProblemsizeNK
=
[
&
]()
{
if
constexpr
(
BEnableLds
)
if
constexpr
(
BEnableLds
)
{
{
return
make_tuple
(
b_grid_desc
_k0_n_k1
.
GetLength
(
I1
),
return
make_tuple
(
b_grid_desc
.
GetLength
(
I1
),
b_grid_desc
_k0_n_k1
.
GetLength
(
I0
)
*
b_grid_desc
.
GetLength
(
I0
)
*
b_grid_desc
_k0_n_k1
.
GetLength
(
I2
));
b_grid_desc
.
GetLength
(
I2
));
}
}
else
else
{
{
return
make_tuple
(
return
make_tuple
(
b_grid_desc
_k0_n_k1
.
GetLength
(
I1
)
*
b_grid_desc
_k0_n_k1
.
GetLength
(
I2
)
*
b_grid_desc
.
GetLength
(
I1
)
*
b_grid_desc
.
GetLength
(
I2
)
*
b_grid_desc
_k0_n_k1
.
GetLength
(
I4
),
b_grid_desc
.
GetLength
(
I4
),
b_grid_desc
_k0_n_k1
.
GetLength
(
I0
)
*
b_grid_desc
_k0_n_k1
.
GetLength
(
I3
)
*
b_grid_desc
.
GetLength
(
I0
)
*
b_grid_desc
.
GetLength
(
I3
)
*
b_grid_desc
_k0_n_k1
.
GetLength
(
I5
));
b_grid_desc
.
GetLength
(
I5
));
}
}
};
};
...
@@ -484,8 +531,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
...
@@ -484,8 +531,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
max_lds_align
)
max_lds_align
)
:
0
;
:
0
;
static
constexpr
auto
b_block_space_size_aligned
=
static
constexpr
auto
b_block_space_size_aligned
=
BEnableLds
?
math
::
integer_least_multiple
(
BEnableLds
?
math
::
integer_least_multiple
(
MakeBBlockDescriptor
().
GetElementSpaceSize
(),
GetBBlockDescriptor_K0PerBlock_NPerBlock_K1
().
GetElementSpaceSize
(),
max_lds_align
)
max_lds_align
)
:
0
;
:
0
;
...
@@ -500,18 +546,18 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
...
@@ -500,18 +546,18 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
static
constexpr
auto
c_shuffle_block_space_offset
=
0
;
static
constexpr
auto
c_shuffle_block_space_offset
=
0
;
static
constexpr
auto
lds_size
=
static
constexpr
auto
lds_size
=
math
::
max
(
c_shuffle_block_space_size
*
sizeof
(
Float
CShuffle
),
math
::
max
(
c_shuffle_block_space_size
*
sizeof
(
CShuffle
DataType
),
a_block_space_size_aligned
*
sizeof
(
FloatA
)
+
a_block_space_size_aligned
*
sizeof
(
ADataType
)
+
b_block_space_size_aligned
*
sizeof
(
FloatB
));
b_block_space_size_aligned
*
sizeof
(
BDataType
));
};
};
template
<
bool
HasMainKBlockLoop
,
typename
Block2CTileMap
=
DefaultBlock2CTileMap
>
template
<
bool
HasMainKBlockLoop
,
typename
Block2CTileMap
=
DefaultBlock2CTileMap
>
__device__
static
void
Run
(
const
FloatA
*
__restrict__
p_a_grid
,
__device__
static
void
Run
(
const
ADataType
*
__restrict__
p_a_grid
,
const
FloatB
*
__restrict__
p_b_grid
,
const
BDataType
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
CDataType
*
__restrict__
p_c_grid
,
void
*
__restrict__
p_shared
,
void
*
__restrict__
p_shared
,
const
AGridDesc
&
a_grid_desc
,
const
AGridDesc
&
a_grid_desc
,
const
BGridDesc
_K0_N_K1
&
b_grid_desc
_k0_n_k1
,
const
BGridDesc
&
b_grid_desc
,
const
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
&
const
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
&
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
const
AElementwiseOperation
&
a_element_op
,
const
AElementwiseOperation
&
a_element_op
,
...
@@ -525,7 +571,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
...
@@ -525,7 +571,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
const
auto
a_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
const
auto
a_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_a_grid
,
a_grid_desc
.
GetElementSpaceSize
());
p_a_grid
,
a_grid_desc
.
GetElementSpaceSize
());
const
auto
b_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
const
auto
b_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_b_grid
,
b_grid_desc
_k0_n_k1
.
GetElementSpaceSize
());
p_b_grid
,
b_grid_desc
.
GetElementSpaceSize
());
auto
c_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
auto
c_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_c_grid
,
c_grid_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
());
p_c_grid
,
c_grid_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
());
...
@@ -554,7 +600,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
...
@@ -554,7 +600,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
}();
}();
constexpr
auto
a_block_desc
=
MakeABlockDescriptor
();
constexpr
auto
a_block_desc
=
MakeABlockDescriptor
();
constexpr
auto
b_block_desc
=
Get
BBlockDescriptor
_K0PerBlock_NPerBlock_K1
();
constexpr
auto
b_block_desc
=
Make
BBlockDescriptor
();
auto
a_block_trait
=
[
&
](){
auto
a_block_trait
=
[
&
](){
// A matrix blockwise copy
// A matrix blockwise copy
...
@@ -562,7 +608,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
...
@@ -562,7 +608,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
{
{
constexpr
auto
K0PerBlock
=
KPerBlock
/
K1
;
constexpr
auto
K0PerBlock
=
KPerBlock
/
K1
;
auto
a_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
auto
a_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
FloatA
*>
(
p_shared
),
static_cast
<
ADataType
*>
(
p_shared
),
SharedMemTrait
::
a_block_space_size_aligned
);
SharedMemTrait
::
a_block_space_size_aligned
);
auto
a_blockwise_copy
=
auto
a_blockwise_copy
=
...
@@ -573,8 +619,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
...
@@ -573,8 +619,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
/* typename BlockSliceLengths, */
Sequence
<
K0PerBlock
,
MPerBlock
,
K1
>
,
/* typename BlockSliceLengths, */
Sequence
<
K0PerBlock
,
MPerBlock
,
K1
>
,
/* typename ThreadClusterLengths, */
ABlockTransferThreadClusterLengths_K0_M_K1
,
/* typename ThreadClusterLengths, */
ABlockTransferThreadClusterLengths_K0_M_K1
,
/* typename ThreadClusterArrangeOrder, */
ABlockTransferThreadClusterArrangeOrder
,
/* typename ThreadClusterArrangeOrder, */
ABlockTransferThreadClusterArrangeOrder
,
/* typename SrcData, */
FloatA
,
/* typename SrcData, */
ADataType
,
/* typename DstData, */
FloatA
,
/* typename DstData, */
ADataType
,
/* typename SrcDesc, */
decltype
(
a_grid_desc
),
/* typename SrcDesc, */
decltype
(
a_grid_desc
),
/* typename DstDesc, */
decltype
(
a_block_desc
),
/* typename DstDesc, */
decltype
(
a_block_desc
),
/* typename SrcDimAccessOrder, */
ABlockTransferSrcAccessOrder
,
/* typename SrcDimAccessOrder, */
ABlockTransferSrcAccessOrder
,
...
@@ -601,13 +647,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
...
@@ -601,13 +647,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
// Thread-wise copy
// Thread-wise copy
// KPerBlock/WmmaK -> MRepeat -> MWaves -> WmmaK/K1 -> MPerWmma -> K1
// KPerBlock/WmmaK -> MRepeat -> MWaves -> WmmaK/K1 -> MPerWmma -> K1
constexpr
auto
KWmmaPerBlock
=
KPerBlock
/
WmmaK
;
constexpr
auto
KWmmaPerBlock
=
KPerBlock
/
WmmaK
;
auto
a_block_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatA
>
(
auto
a_block_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ADataType
>
(
a_block_desc
.
GetElementSpaceSize
());
a_block_desc
.
GetElementSpaceSize
());
// Limitation: NumDim of Src and Dst descriptor should be identical
// Limitation: NumDim of Src and Dst descriptor should be identical
auto
a_blockwise_copy
=
auto
a_blockwise_copy
=
ThreadwiseTensorSliceTransfer_v2
<
FloatA
,
ThreadwiseTensorSliceTransfer_v2
<
ADataType
,
FloatA
,
ADataType
,
decltype
(
a_grid_desc
),
decltype
(
a_grid_desc
),
decltype
(
a_block_desc
),
decltype
(
a_block_desc
),
Sequence
<
Number
<
KWmmaPerBlock
>
{},
Sequence
<
Number
<
KWmmaPerBlock
>
{},
...
@@ -638,7 +684,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
...
@@ -638,7 +684,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
{
{
constexpr
auto
K0PerBlock
=
KPerBlock
/
K1
;
constexpr
auto
K0PerBlock
=
KPerBlock
/
K1
;
auto
b_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
auto
b_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
FloatB
*>
(
p_shared
)
+
SharedMemTrait
::
b_block_space_offset
,
static_cast
<
BDataType
*>
(
p_shared
)
+
SharedMemTrait
::
b_block_space_offset
,
SharedMemTrait
::
b_block_space_size_aligned
);
SharedMemTrait
::
b_block_space_size_aligned
);
auto
b_blockwise_copy
=
auto
b_blockwise_copy
=
...
@@ -649,9 +695,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
...
@@ -649,9 +695,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
Sequence
<
K0PerBlock
,
NPerBlock
,
K1
>
,
Sequence
<
K0PerBlock
,
NPerBlock
,
K1
>
,
BBlockTransferThreadClusterLengths_K0_N_K1
,
BBlockTransferThreadClusterLengths_K0_N_K1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferThreadClusterArrangeOrder
,
FloatB
,
BDataType
,
FloatB
,
BDataType
,
decltype
(
b_grid_desc
_k0_n_k1
),
decltype
(
b_grid_desc
),
decltype
(
b_block_desc
),
decltype
(
b_block_desc
),
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcAccessOrder
,
Sequence
<
0
,
1
,
2
>
,
Sequence
<
0
,
1
,
2
>
,
...
@@ -663,7 +709,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
...
@@ -663,7 +709,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
1
,
1
,
BThreadTransferSrcResetCoordinateAfterRun
,
BThreadTransferSrcResetCoordinateAfterRun
,
true
>
(
true
>
(
b_grid_desc
_k0_n_k1
,
b_grid_desc
,
make_multi_index
(
0
,
n_block_data_idx_on_grid
,
0
),
make_multi_index
(
0
,
n_block_data_idx_on_grid
,
0
),
b_element_op
,
b_element_op
,
b_block_desc
,
b_block_desc
,
...
@@ -674,22 +720,36 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
...
@@ -674,22 +720,36 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
}
}
else
else
{
{
constexpr
auto
K0PerBlock
=
KPerBlock
/
K1
;
// Thread-wise copy
auto
b_block_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatB
>
(
// KPerBlock/WmmaK -> NRepeat -> NWaves -> WmmaK/K1 -> NPerWmma -> K1
constexpr
auto
KWmmaPerBlock
=
KPerBlock
/
WmmaK
;
auto
b_block_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ADataType
>
(
b_block_desc
.
GetElementSpaceSize
());
b_block_desc
.
GetElementSpaceSize
());
// Limitation: NumDim of Src and Dst descriptor should be identical
auto
b_blockwise_copy
=
auto
b_blockwise_copy
=
ThreadwiseTensorSliceTransfer_v
4
<
FloatB
,
ThreadwiseTensorSliceTransfer_v
2
<
BDataType
,
FloatB
,
BDataType
,
decltype
(
b_grid_desc
_k0_n_k1
),
decltype
(
b_grid_desc
),
decltype
(
b_block_desc
),
decltype
(
b_block_desc
),
Sequence
<
Number
<
K
0
PerBlock
>
{},
Sequence
<
Number
<
K
Wmma
PerBlock
>
{},
Number
<
NRepeat
>
{},
Number
<
NRepeat
>
{},
I1
,
I1
,
I1
,
Number
<
K1Value
>
{}
>
,
Number
<
K1Value
>
{}
>
,
Sequence
<
0
,
1
,
2
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
2
,
5
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
1
>
(
BThreadTransferSrcResetCoordinateAfterRun
,
make_multi_index
(
0
,
get_thread_local_1d_id
()
/
32
*
16
+
get_thread_local_1d_id
()
%
16
,
0
));
true
>
(
b_grid_desc
,
make_multi_index
(
0
,
n_block_data_idx_on_grid
/
(
NWaves
*
NPerWmma
),
get_thread_local_1d_id
()
/
32
,
(
get_thread_local_1d_id
()
%
32
)
/
16
,
get_thread_local_1d_id
()
%
16
,
0
));
return
make_tuple
(
b_block_buf
,
b_blockwise_copy
);
return
make_tuple
(
b_block_buf
,
b_blockwise_copy
);
}
}
...
@@ -706,11 +766,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
...
@@ -706,11 +766,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
auto
blockwise_gemm
=
auto
blockwise_gemm
=
BlockwiseGemmWMMA
<
BlockSize
,
BlockwiseGemmWMMA
<
BlockSize
,
FloatA
,
ADataType
,
FloatB
,
BDataType
,
FloatAcc
,
AccDataType
,
decltype
(
MakeAWaveDescriptor
(
a_block_desc
)),
decltype
(
MakeAWaveDescriptor
(
a_block_desc
)),
decltype
(
MakeB
Block
Descriptor
_K0_N0_N1_N2_K1
(
b_block_desc
)),
decltype
(
MakeB
Wave
Descriptor
(
b_block_desc
)),
MPerBlock
,
MPerBlock
,
NPerBlock
,
NPerBlock
,
KPerBlock
,
KPerBlock
,
...
@@ -738,7 +798,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
...
@@ -738,7 +798,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
a_grid_buf
,
a_grid_buf
,
a_block_buf
,
a_block_buf
,
a_block_slice_copy_step
,
a_block_slice_copy_step
,
b_grid_desc
_k0_n_k1
,
b_grid_desc
,
b_block_desc
,
b_block_desc
,
b_blockwise_copy
,
b_blockwise_copy
,
b_grid_buf
,
b_grid_buf
,
...
@@ -768,7 +828,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
...
@@ -768,7 +828,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat
();
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat
();
auto
c_shuffle_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
auto
c_shuffle_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
Float
CShuffle
*>
(
p_shared
)
+
SharedMemTrait
::
c_shuffle_block_space_offset
,
static_cast
<
CShuffle
DataType
*>
(
p_shared
)
+
SharedMemTrait
::
c_shuffle_block_space_offset
,
SharedMemTrait
::
c_shuffle_block_space_size
);
SharedMemTrait
::
c_shuffle_block_space_size
);
constexpr
auto
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs
=
transform_tensor_descriptor
(
constexpr
auto
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs
=
transform_tensor_descriptor
(
...
@@ -815,8 +875,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
...
@@ -815,8 +875,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
// shuffle: threadwise copy C from VGPR to LDS
// shuffle: threadwise copy C from VGPR to LDS
auto
c_thread_copy_vgpr_to_lds
=
auto
c_thread_copy_vgpr_to_lds
=
ThreadwiseTensorSliceTransfer_v1r3
<
FloatAcc
,
ThreadwiseTensorSliceTransfer_v1r3
<
AccDataType
,
Float
CShuffle
,
CShuffle
DataType
,
decltype
(
c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs
),
decltype
(
c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs
),
decltype
(
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs
),
decltype
(
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs
),
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
...
@@ -854,8 +914,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
...
@@ -854,8 +914,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
CShuffleNRepeatPerShuffle
*
NWave
*
NPerWmma
>
,
// BlockSliceLengths,
CShuffleNRepeatPerShuffle
*
NWave
*
NPerWmma
>
,
// BlockSliceLengths,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
Sequence
<
0
,
1
,
2
,
3
>
,
// typename ThreadClusterArrangeOrder,
Sequence
<
0
,
1
,
2
,
3
>
,
// typename ThreadClusterArrangeOrder,
Float
CShuffle
,
// typename SrcData,
CShuffle
DataType
,
// typename SrcData,
FloatC
,
// typename DstData,
CDataType
,
// typename DstData,
decltype
(
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat
),
decltype
(
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat
),
decltype
(
c_grid_desc_mblock_mperblock_nblock_nperblock
),
decltype
(
c_grid_desc_mblock_mperblock_nblock_nperblock
),
Sequence
<
0
,
1
,
2
,
3
>
,
// typename DimAccessOrder,
Sequence
<
0
,
1
,
2
,
3
>
,
// typename DimAccessOrder,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment