Project: gaoqiong/composable_kernel
Commit: a38ce024 (parent 686212eb)
Author: aska-0096
Date: Mar 06, 2023

    batched gemm ported

Showing 8 changed files with 204 additions and 176 deletions (+204, -176).
example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp            +4    -4
example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt                                   +8    -0
include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp   +125  -111
include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp                            +4    -3
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp   +15   -4
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp             +46   -48
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp                          +0    -4
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp                                 +2    -2
example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp

@@ -70,12 +70,12 @@ using DeviceOpInstanceKKNN =
        256, 128, 128, 4, 32, 8, 16, 16, 4, 2, 1, 8,
        S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>,
@@ -92,7 +92,7 @@ using DeviceOpInstanceKKNN =
        true, 1, 1, S<1, 32, 1, 8>, S<1, 128, 1, 2>, 8>;

using DeviceOpInstance = DeviceOpInstanceKKNN;
example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt

@@ -5,6 +5,9 @@ add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16
 add_example_executable(example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp)
 add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp)
 add_example_executable(example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp)
+if(GPU_TARGETS MATCHES "gfx1100")
+    add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp)
+endif()
 add_custom_target(example_gemm_scale_softmax_gemm)
 add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16)
@@ -14,3 +17,8 @@ add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_soft
 add_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16)
 add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16)
 add_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16)
+if(GPU_TARGETS MATCHES "gfx1100")
+    add_custom_target(example_gemm_scale_softmax_gemm_wmma)
+    add_dependencies(example_gemm_scale_softmax_gemm_wmma example_batched_gemm_scale_softmax_gemm_permute_wmma_fp16)
+endif()
\ No newline at end of file
include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp

@@ -76,10 +76,10 @@ template <index_t NumDimG,
           ck::index_t BlockSize,
           ck::index_t MPerBlock,
           ck::index_t NPerBlock,
-          ck::index_t K0PerBlock,
+          ck::index_t KPerBlock,
           ck::index_t K1,
-          ck::index_t MPerWMMA,
-          ck::index_t NPerWMMA,
+          ck::index_t MPerWmma,
+          ck::index_t NPerWmma,
           ck::index_t MRepeat,
           ck::index_t NRepeat,
           typename ABlockTransferThreadClusterLengths_K0_M_K1,
@@ -123,14 +123,23 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
+   static constexpr auto I4 = Number<4>{};
+   static constexpr auto I5 = Number<5>{};

    // K1 = Max Vector Access Pixels
    static constexpr auto K1Number = Number<K1>{};

+   static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
+   static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
+   static constexpr auto WmmaK  = 16;
+
+   static constexpr auto AEnableLds = NWaves == 1 ? false : true;
+   static constexpr auto BEnableLds = MWaves == 1 ? false : true;
+
    static constexpr auto matrix_padder =
-       MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, K0PerBlock * K1};
+       MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};

    // Assume: A[G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...]
-   static auto MakeAGridDescriptor_M_K(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
+   static auto MakeAGridDescriptor(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
                                        const std::vector<index_t>& a_gs_ms_ks_strides_vec)
    {
        assert(a_gs_ms_ks_lengths_vec.size() == NumDimG + NumDimM + NumDimK &&
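A quick reference for the constants introduced above (this sketch is not part of the diff): the wave counts follow from the tile shape, the former K0PerBlock is recovered as KPerBlock / K1, and LDS staging is switched off for A or B when the corresponding wave count is 1. The tile sizes below are illustrative assumptions, not the example's actual template arguments:

    #include <cstdio>

    int main()
    {
        // Illustrative values only; the real ones are device-op template parameters.
        constexpr int MPerBlock = 128, NPerBlock = 128, KPerBlock = 32, K1 = 8;
        constexpr int MPerWmma = 16, NPerWmma = 16, MRepeat = 4, NRepeat = 2;

        constexpr int MWaves = MPerBlock / (MRepeat * MPerWmma); // 2
        constexpr int NWaves = NPerBlock / (NRepeat * NPerWmma); // 4
        constexpr int K0     = KPerBlock / K1;                   // 4, the former K0PerBlock

        // LDS staging of A (resp. B) is bypassed only when NWaves (resp. MWaves) == 1.
        constexpr bool AEnableLds = NWaves == 1 ? false : true;
        constexpr bool BEnableLds = MWaves == 1 ? false : true;

        std::printf("MWaves=%d NWaves=%d K0=%d AEnableLds=%d BEnableLds=%d\n",
                    MWaves, NWaves, K0, AEnableLds, BEnableLds);
    }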
@@ -158,6 +167,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
        // lengths for K0, K1, ...
        const auto kLengths = get_container_subset(a_ms_ks_lengths, kDimIds);

+       const auto a_grid_desc_m_k = [&]() {
            if constexpr(ASpec == TensorSpecialization::Packed)
            {
                auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{});
@@ -183,10 +193,42 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
                return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
            }
        }();

+       const auto M = a_grid_desc_m_k.GetLength(I0);
+       const auto K = a_grid_desc_m_k.GetLength(I1);
+       assert(K % K1 == 0);
+
+       if constexpr(AEnableLds)
+       {
+           const index_t K0 = K / K1;
+           return transform_tensor_descriptor(
+               a_grid_desc_m_k,
+               make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
+                          make_pass_through_transform(M)),
+               make_tuple(Sequence<1>{}, Sequence<0>{}),
+               make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+       }
+       else
+       {
+           constexpr auto A_KRow = WmmaK / K1;
+           const auto A_KWmma    = K / WmmaK;
+           const auto M0         = M / MPerBlock;
+           return transform_tensor_descriptor(
+               a_grid_desc_m_k,
+               make_tuple(make_unmerge_transform(make_tuple(A_KWmma, Number<A_KRow>{}, K1Number)),
+                          make_unmerge_transform(
+                              make_tuple(M0 * MRepeat, Number<MWaves>{}, Number<MPerWmma>{}))),
+               make_tuple(Sequence<1>{}, Sequence<0>{}),
+               make_tuple(Sequence<0, 3, 5>{}, Sequence<1, 2, 4>{}));
+       }
    }
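A small arithmetic sketch of the two branches above (not part of the diff; the variable names mirror the code, the numeric values are assumed): with AEnableLds the A tensor is re-described as (K0, M, K1); without it, K is unmerged into (K / WmmaK, WmmaK / K1, K1) so threads can read WMMA operands directly from global memory.

    #include <cstdio>

    int main()
    {
        // Illustrative extents; K and K1 are normally taken from the problem and the device op.
        constexpr int K = 256, K1 = 8;
        constexpr int WmmaK = 16; // per-WMMA K depth used by this device op

        constexpr int K0      = K / K1;    // LDS path: A viewed as (K0, M, K1)
        constexpr int A_KWmma = K / WmmaK; // direct path: outer K factor
        constexpr int A_KRow  = WmmaK / K1;

        std::printf("K0=%d A_KWmma=%d A_KRow=%d K recomposed=%d\n",
                    K0, A_KWmma, A_KRow, A_KWmma * A_KRow * K1);
    }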
    // Assume: B[G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...]
-   static auto MakeBGridDescriptor_N_K(const std::vector<index_t>& b_gs_ns_ks_lengths_vec,
+   static auto MakeBGridDescriptor_K0_N_K1(const std::vector<index_t>& b_gs_ns_ks_lengths_vec,
                                            const std::vector<index_t>& b_gs_ns_ks_strides_vec)
    {
        assert(b_gs_ns_ks_lengths_vec.size() == NumDimG + NumDimN + NumDimK &&
@@ -214,6 +256,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
        // lengths for N0, N1, ...
        const auto nLengths = get_container_subset(b_ns_ks_lengths, nDimIds);

+       const auto b_grid_desc_n_k = [&]() {
            if constexpr(BSpec == TensorSpecialization::Packed)
            {
                auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{});
@@ -239,6 +282,19 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
                return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
            }
        }();

+       const auto N = b_grid_desc_n_k.GetLength(I0);
+       const auto K = b_grid_desc_n_k.GetLength(I1);
+       assert(K % K1 == 0);
+
+       const index_t K0 = K / K1;
+       return transform_tensor_descriptor(
+           b_grid_desc_n_k,
+           make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
+                      make_pass_through_transform(N)),
+           make_tuple(Sequence<1>{}, Sequence<0>{}),
+           make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
    }

    // assume E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
@@ -393,8 +449,6 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
    }

    // Gridwise descriptor, mapping to whole given provblem.
-   using AGridDesc_M_K  = decltype(MakeAGridDescriptor_M_K({}, {}));
-   using BGridDesc_N_K  = decltype(MakeBGridDescriptor_N_K({}, {}));
    using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}))>;
    using EGridDesc_M_N  = decltype(MakeEGridDescriptor_M_N({}, {}));
@@ -449,42 +503,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
        EGridDesc_G_M_N e_grid_desc_g_m_n_;
    };

-   // A desc for source in blockwise copy
-   template <typename AGridDesc_M_K>
-   __host__ __device__ static constexpr auto
-   MakeAGridDescriptor_K0_M_K1(const AGridDesc_M_K& a_grid_desc_m_k)
-   {
-       const auto M = a_grid_desc_m_k.GetLength(I0);
-       const auto K = a_grid_desc_m_k.GetLength(I1);
-
-       const auto AK0 = K / K1;
-
-       return transform_tensor_descriptor(a_grid_desc_m_k,
-                                          make_tuple(make_unmerge_transform(make_tuple(AK0, K1)),
-                                                     make_pass_through_transform(M)),
-                                          make_tuple(Sequence<1>{}, Sequence<0>{}),
-                                          make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-   }
-
-   // B desc for source in blockwise copy
-   template <typename BGridDesc_N_K>
-   __host__ __device__ static constexpr auto
-   MakeBGridDescriptor_K0_N_K1(const BGridDesc_N_K& b_grid_desc_n_k)
-   {
-       const auto N = b_grid_desc_n_k.GetLength(I0);
-       const auto K = b_grid_desc_n_k.GetLength(I1);
-
-       const auto BK0 = K / K1;
-
-       return transform_tensor_descriptor(b_grid_desc_n_k,
-                                          make_tuple(make_unmerge_transform(make_tuple(BK0, K1)),
-                                                     make_pass_through_transform(N)),
-                                          make_tuple(Sequence<1>{}, Sequence<0>{}),
-                                          make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-   }
-
-   using AGridDesc_K0_M_K1 = decltype(DeviceOp::MakeAGridDescriptor_K0_M_K1(AGridDesc_M_K{}));
-   using BGridDesc_K0_N_K1 = decltype(DeviceOp::MakeBGridDescriptor_K0_N_K1(BGridDesc_N_K{}));
+   using AGridDesc         = decltype(DeviceOp::MakeAGridDescriptor({}, {}));
+   using BGridDesc_K0_N_K1 = decltype(DeviceOp::MakeBGridDescriptor_K0_N_K1({}, {}));

    // GridwiseOp
    using GridwiseOp = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle<
@@ -496,7 +516,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
        DsDataType,
        EDataType,
        // InMemory Data Descriptor
-       AGridDesc_K0_M_K1,
+       AGridDesc,
        BGridDesc_K0_N_K1,
        DsGridDesc_M_N,
        EGridDesc_M_N,
@@ -508,9 +528,9 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
        // Tiling Family
        MPerBlock,
        NPerBlock,
-       K0PerBlock,
-       MPerWMMA,
-       NPerWMMA,
+       KPerBlock,
+       MPerWmma,
+       NPerWmma,
        K1,
        MRepeat,
        NRepeat,
@@ -523,6 +543,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
        ABlockTransferSrcScalarPerVector,
        ABlockTransferDstScalarPerVector_K1,
        false, // AThreadTransferSrcResetCoordinateAfterRun,
+       AEnableLds,
        ABlockLdsAddExtraM,
        BBlockTransferThreadClusterLengths_K0_N_K1,
        BBlockTransferThreadClusterArrangeOrder,
@@ -531,6 +552,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
        BBlockTransferSrcScalarPerVector,
        BBlockTransferDstScalarPerVector_K1,
        false, // BThreadTransferSrcResetCoordinateAfterRun,
+       BEnableLds,
        BBlockLdsAddExtraN,
        CShuffleMRepeatPerShuffle,
        CShuffleNRepeatPerShuffle,
@@ -564,16 +586,14 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
              p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
              p_ds_grid_{},
              p_e_grid_{static_cast<EDataType*>(p_e_grid)},
-             a_grid_desc_m_k_{},
-             b_grid_desc_n_k_{},
+             a_grid_desc_k0_m_k1_{},
+             b_grid_desc_k0_n_k1_{},
              ds_grid_desc_m_n_{},
              e_grid_desc_m_n_{},
              ds_grid_desc_g_m_n_{DeviceOp::MakeDsGridDescriptor_G_M_N(ds_gs_ms_ns_lengths,
                                                                       ds_gs_ms_ns_strides)},
              e_grid_desc_g_m_n_{DeviceOp::MakeEGridDescriptor_G_M_N(e_gs_ms_ns_lengths,
                                                                     e_gs_ms_ns_strides)},
-             a_grid_desc_k0_m_k1_{},
-             b_grid_desc_k0_n_k1_{},
              ds_grid_desc_mblock_mperblock_nblock_nperblock{},
              e_grid_desc_mblock_mperblock_nblock_nperblock{},
              block_2_ctile_map_{},
@@ -600,10 +620,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
                p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);
            });

-           a_grid_desc_m_k_ = DeviceOp::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
-           b_grid_desc_n_k_ = DeviceOp::MakeBGridDescriptor_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides);
+           a_grid_desc_k0_m_k1_ = DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
+           b_grid_desc_k0_n_k1_ = DeviceOp::MakeBGridDescriptor_K0_N_K1(b_gs_ns_ks_lengths, b_gs_ns_ks_strides);
            ds_grid_desc_m_n_ = DeviceOp::MakeDsGridDescriptor_M_N(ds_gs_ms_ns_lengths, ds_gs_ms_ns_strides);
@@ -611,8 +629,6 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
            e_grid_desc_m_n_ = DeviceOp::MakeEGridDescriptor_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides);

-           a_grid_desc_k0_m_k1_ = DeviceOp::MakeAGridDescriptor_K0_M_K1(a_grid_desc_m_k_);
-           b_grid_desc_k0_n_k1_ = DeviceOp::MakeBGridDescriptor_K0_N_K1(b_grid_desc_n_k_);
            block_2_ctile_map_ = GridwiseOp::MakeDefaultBlock2CTileMap(e_grid_desc_m_n_, M01, N01);
@@ -644,15 +660,13 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
        EDataType* p_e_grid_;

        // Tensor Descriptors
-       AGridDesc_M_K a_grid_desc_m_k_;
-       BGridDesc_N_K b_grid_desc_n_k_;
+       AGridDesc a_grid_desc_k0_m_k1_;
+       BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
        DsGridDesc_M_N ds_grid_desc_m_n_;
        EGridDesc_M_N e_grid_desc_m_n_;
        DsGridDesc_G_M_N ds_grid_desc_g_m_n_;
        EGridDesc_G_M_N e_grid_desc_g_m_n_;
-       AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
-       BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
        typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
            ds_grid_desc_mblock_mperblock_nblock_nperblock;
@@ -712,7 +726,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
                BDataType,
                typename GridwiseOp::DsGridPointer,
                EDataType,
-               DeviceOp::AGridDesc_K0_M_K1,
+               DeviceOp::AGridDesc,
                DeviceOp::BGridDesc_K0_N_K1,
                typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
                typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
@@ -975,10 +989,10 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
            << BlockSize << ", "
            << MPerBlock << ", "
            << NPerBlock << ", "
-           << K0PerBlock << ", "
+           << KPerBlock << ", "
            << K1 << ", "
-           << MPerWMMA << ", "
-           << NPerWMMA << ", "
+           << MPerWmma << ", "
+           << NPerWmma << ", "
            << MRepeat << ", "
            << NRepeat << ">"
include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp

@@ -89,8 +89,9 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
    static constexpr auto AEnableLds = NWaves == 1 ? false : true;
    static constexpr auto BEnableLds = MWaves == 1 ? false : true;
-   // static constexpr auto AEnableLds = true;
-   // static constexpr auto BEnableLds = true;
+   // Force enable LDS if uncommented following
+   // AEnableLds = true;
+   // BEnableLds = true;

    static constexpr auto matrix_padder =
        MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
@@ -124,7 +125,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
            return transform_tensor_descriptor(
                a_grid_desc_m_k,
                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
-                          make_pass_through_transform(a_grid_desc_m_k.GetLength(I0))),
+                          make_pass_through_transform(M)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        }
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp

@@ -296,20 +296,27 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
            constexpr auto KWmma = ABlockDesc_{}.GetLength(I0);
            constexpr auto A_K1  = ABlockDesc_{}.GetLength(I5);

+           // Workaround, Freeze transform
            return transform_tensor_descriptor(
                ABlockDesc_{},
-               make_tuple(make_merge_transform(make_tuple(Number<KWmma>{}, I1)),
+               make_tuple(make_freeze_transform(I0),
+                          make_pass_through_transform(Number<KWmma>{}),
                           make_pass_through_transform(Number<MRepeat>{}),
                           make_pass_through_transform(I1),
                           make_pass_through_transform(I1),
                           make_pass_through_transform(Number<A_K1>{})),
-               make_tuple(Sequence<0, 3>{},
+               make_tuple(Sequence<3>{}, Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<5>{}),
-               make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
+               make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
        }
    }();
@@ -782,6 +789,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
                                                           MRepeat,
                                                           LRepeat,
                                                           KPack,
+                                                          AEnableLds,
+                                                          B0EnableLds,
                                                           true>{};

    // C' = B' x A'
@@ -968,6 +977,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
            MRepeat,
            NRepeat,
            KPack,
+           false,
+           B1EnableLds,
            true>{make_tuple(0, 0, 0, 0, 0)};

        auto acc1_thread_buf = blockwise_gemm1.GetCThreadBuffer();
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp

@@ -69,7 +69,7 @@ __global__ void
    const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);

-   __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()];
+   __shared__ char p_shared[GridwiseOp::SharedMemTrait::lds_size];

    DsPointer p_ds_grid_grp;
@@ -148,7 +148,7 @@ __global__ void
               const Block2CTileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__))
-   __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()];
+   __shared__ char p_shared[GridwiseOp::SharedMemTrait::lds_size];

    const index_t num_blocks_per_batch =
        __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
@@ -237,7 +237,7 @@ __global__ void
               const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__))
-   __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()];
+   __shared__ char p_shared[GridwiseOp::SharedMemTrait::lds_size];

    GridwiseOp::template Run<HasMainKBlockLoop>(p_a_grid,
                                                p_b_grid,
@@ -451,20 +451,27 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
            constexpr auto KWmma = ABlockDesc_{}.GetLength(I0);
            constexpr auto A_K1  = ABlockDesc_{}.GetLength(I5);

+           // Workaround, Freeze transform
            return transform_tensor_descriptor(
                ABlockDesc_{},
-               make_tuple(make_merge_transform(make_tuple(Number<KWmma>{}, I1)),
+               make_tuple(make_freeze_transform(I0),
+                          make_pass_through_transform(Number<KWmma>{}),
                           make_pass_through_transform(Number<MRepeat>{}),
                           make_pass_through_transform(I1),
                           make_pass_through_transform(I1),
                           make_pass_through_transform(Number<A_K1>{})),
-               make_tuple(Sequence<0, 3>{},
+               make_tuple(Sequence<3>{}, Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<5>{}),
-               make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
+               make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
        }
    }();
@@ -540,19 +547,6 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
                       Number<NumDTensor>{});
    }

-   __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
-   {
-       // LDS allocation for A and B: be careful of alignment
-       const index_t gemm_bytes_end =
-           SharedMemTrait::a_block_space_size_aligned * sizeof(ADataType) +
-           SharedMemTrait::b_block_space_size_aligned * sizeof(BDataType);
-       const index_t c_block_bytes_end =
-           SharedMemTrait::c_shuffle_block_space_size * sizeof(CShuffleDataType);
-
-       return math::max(gemm_bytes_end, c_block_bytes_end);
-   }
-
    // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
    template <typename Block2CTileMap>
    __host__ __device__ static constexpr bool
@@ -650,7 +644,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
    __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
    {
-       const index_t num_loop = K / (K0PerBlock * K1);
+       const index_t num_loop = K / KPerBlock;

        return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
    }
@@ -704,11 +698,13 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
        static constexpr auto a_block_space_size_aligned =
            AEnableLds ? math::integer_least_multiple(MakeABlockDescriptor().GetElementSpaceSize(),
                                                      max_lds_align)
                       : 0;
        static constexpr auto b_block_space_size_aligned =
            BEnableLds
                ? math::integer_least_multiple(
                      GetBBlockDescriptor_K0PerBlock_NPerBlock_K1().GetElementSpaceSize(),
                      max_lds_align)
                : 0;

        static constexpr auto a_block_space_offset = 0;
        static constexpr auto b_block_space_offset = a_block_space_size_aligned;
@@ -719,6 +715,11 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
                .GetElementSpaceSize();
        static constexpr auto c_shuffle_block_space_offset = 0;

+       static constexpr auto lds_size =
+           math::max(c_shuffle_block_space_size * sizeof(CShuffleDataType),
+                     a_block_space_size_aligned * sizeof(ADataType) +
+                         b_block_space_size_aligned * sizeof(BDataType));
    };

    using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
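For orientation (a sketch, not part of the diff): the new SharedMemTrait::lds_size keeps the rule that the removed GetSharedMemoryNumberOfByte() used; since the A/B staging tiles and the C-shuffle tile share one LDS allocation, the budget is the larger of the two. Element counts below are assumed, not taken from the code:

    #include <algorithm>
    #include <cstddef>
    #include <cstdio>

    int main()
    {
        // Hypothetical element counts; the real ones come from the block descriptors.
        constexpr std::size_t a_block_space_size_aligned = 4096; // ADataType elements
        constexpr std::size_t b_block_space_size_aligned = 4096; // BDataType elements
        constexpr std::size_t c_shuffle_block_space_size = 8192; // CShuffleDataType elements
        constexpr std::size_t elem = 2;                          // e.g. sizeof(fp16)

        // Same max() rule as the diff: A+B staging and the C-shuffle buffer reuse the same LDS.
        constexpr std::size_t lds_size =
            std::max(c_shuffle_block_space_size * elem,
                     a_block_space_size_aligned * elem + b_block_space_size_aligned * elem);

        std::printf("lds_size = %zu bytes\n", lds_size);
    }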
@@ -796,7 +797,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
            {
                constexpr auto K0PerBlock = KPerBlock / K1;
                auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-                   static_cast<FloatA*>(p_shared),
+                   static_cast<ADataType*>(p_shared),
                    a_block_desc.GetElementSpaceSize());

                auto a_blockwise_copy =
@@ -807,8 +808,8 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
                    /* typename BlockSliceLengths, */ Sequence<K0PerBlock, MPerBlock, K1>,
                    /* typename ThreadClusterLengths, */ ABlockTransferThreadClusterLengths_K0_M_K1,
                    /* typename ThreadClusterArrangeOrder, */ ABlockTransferThreadClusterArrangeOrder,
-                   /* typename SrcData, */ FloatA,
-                   /* typename DstData, */ FloatA,
+                   /* typename SrcData, */ ADataType,
+                   /* typename DstData, */ ADataType,
                    /* typename SrcDesc, */ decltype(a_grid_desc),
                    /* typename DstDesc, */ decltype(a_block_desc),
                    /* typename SrcDimAccessOrder, */ ABlockTransferSrcAccessOrder,
@@ -835,13 +836,13 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
                // Thread-wise copy
                // KPerBlock/WmmaK -> MRepeat -> MWaves -> WmmaK/K1 -> MPerWmma -> K1
                constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
-               auto a_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatA>(
+               auto a_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ADataType>(
                    a_block_desc.GetElementSpaceSize());

                // Limitation: NumDim of Src and Dst descriptor should be identical
                auto a_blockwise_copy =
-                   ThreadwiseTensorSliceTransfer_v2<FloatA,
-                                                    FloatA,
+                   ThreadwiseTensorSliceTransfer_v2<ADataType,
+                                                    ADataType,
                                                     decltype(a_grid_desc),
                                                     decltype(a_block_desc),
                                                     Sequence<Number<KWmmaPerBlock>{},
@@ -872,7 +873,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
            {
                constexpr auto K0PerBlock = KPerBlock / K1;
                auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-                   static_cast<FloatB*>(p_shared) + SharedMemTrait::a_block_space_size_aligned,
+                   static_cast<BDataType*>(p_shared) + SharedMemTrait::a_block_space_size_aligned,
                    b_block_desc.GetElementSpaceSize());

                auto b_blockwise_copy =
@@ -883,8 +884,8 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
                    Sequence<K0PerBlock, NPerBlock, K1>,
                    BBlockTransferThreadClusterLengths_K0_N_K1,
                    BBlockTransferThreadClusterArrangeOrder,
-                   FloatB,
-                   FloatB,
+                   BDataType,
+                   BDataType,
                    decltype(b_grid_desc_k0_n_k1),
                    decltype(b_block_desc),
                    BBlockTransferSrcAccessOrder,
@@ -909,11 +910,11 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
            else
            {
                constexpr auto K0PerBlock = KPerBlock / K1;
-               auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatB>(
+               auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
                    b_block_desc.GetElementSpaceSize());

                auto b_blockwise_copy =
-                   ThreadwiseTensorSliceTransfer_v4<FloatB,
-                                                    FloatB,
+                   ThreadwiseTensorSliceTransfer_v4<BDataType,
+                                                    BDataType,
                                                     decltype(b_grid_desc_k0_n_k1),
                                                     decltype(b_block_desc),
                                                     Sequence<Number<K0PerBlock>{},
@@ -952,38 +953,35 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
                                                              NPerWmma,
                                                              MRepeat,
                                                              NRepeat,
-                                                             KPack>{};
+                                                             KPack,
+                                                             AEnableLds,
+                                                             BEnableLds>{};

        // Prepare Register for C matrix
        auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();

        /*******************************************************************************/
-       constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
-           a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize(), max_lds_align);
-
-       // LDS allocation for A and B: be careful of alignment
-       auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-           static_cast<ADataType*>(p_shared),
-           a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize());
-
-       auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-           static_cast<BDataType*>(p_shared) + a_block_space_size_aligned,
-           b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize());
-
        // Shift Per SUB_K
-       constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
-       constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
+       constexpr auto a_block_slice_copy_step = MakeABlockSliceCopyStep();
+       constexpr auto b_block_slice_copy_step = MakeBBlockSliceCopyStep();

        // gridwise GEMM pipeline
-       const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
+       const index_t KBlockMainLoop = __builtin_amdgcn_readfirstlane(K / KPerBlock);
        GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc,
-                                                         a_block_desc_k0perblock_mperblock_k1,
+                                                         a_block_desc,
                                                          a_blockwise_copy,
                                                          a_grid_buf,
                                                          a_block_buf,
                                                          a_block_slice_copy_step,
                                                          b_grid_desc_k0_n_k1,
-                                                         b_block_desc_k0perblock_nperblock_k1,
+                                                         b_block_desc,
                                                          b_blockwise_copy,
                                                          b_grid_buf,
                                                          b_block_buf,
                                                          b_block_slice_copy_step,
                                                          blockwise_gemm,
                                                          c_thread_buf,
-                                                         K0BlockMainLoop);
+                                                         KBlockMainLoop);
        /*******************************************************************************/
        // write out to C, implement shuffle
        {
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp

@@ -56,8 +56,6 @@ struct GridwiseGemmPipeline_v1<1, true, true>
                               CThreadBuffer& c_thread_buf,
                               index_t num_loop)
    {
-       if(get_thread_local_1d_id() < 32);
-       printf("Mat-A Lds Enabled, Mat-B Lds Enabled\n");
        // preload data into LDS
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
@@ -306,8 +304,6 @@ struct GridwiseGemmPipeline_v1<1, false, true>
            },
            Number<a_block_desc.GetLengths().GetSize()>{});
#endif
-       if(get_thread_local_1d_id() < 32);
-       printf("Mat-A Lds Disabled, Mat-B Lds Enabled\n");
        constexpr auto a_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0);

        auto a_block_buf_switch = a_block_buf;
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp

@@ -731,7 +731,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
        constexpr auto b_block_slice_copy_step = MakeBBlockSliceCopyStep();

        // gridwise GEMM pipeline
-       const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K / KPerBlock);
+       const index_t KBlockMainLoop = __builtin_amdgcn_readfirstlane(K / KPerBlock);
        GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc,
                                                          a_block_desc,
                                                          a_blockwise_copy,
@@ -746,7 +746,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
                                                          b_block_slice_copy_step,
                                                          blockwise_gemm,
                                                          c_thread_buf,
-                                                         K0BlockMainLoop);
+                                                         KBlockMainLoop);
        /*******************************************************************************/
        // write out to C, implement shuffle
        {