gaoqiong / composable_kernel · Commit 8dbb73b1
Commit 8dbb73b1, authored Feb 16, 2023 by aska-0096
Parent: cc6a534f

format

Showing 14 changed files with 186 additions and 171 deletions (+186 -171). The commit is a clang-format pass: most hunks only re-wrap lines, and those are shown once below without -/+ markers. The substantive changes are the gfx1100 guards in example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt, brace and comment reformatting in threadwise_tensor_slice_transfer.hpp, and switching script/clang-format-overwrite.sh to format only modified files.
example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp  (+49 -49)
example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt  (+7 -3)
example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp  (+11 -11)
example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc  (+1 -1)
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp  (+13 -21)
include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp  (+15 -12)
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp  (+22 -23)
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp  (+26 -23)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp  (+2 -2)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp  (+9 -7)
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp  (+25 -12)
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp  (+1 -1)
include/ck/utility/amd_wmma.hpp  (+3 -4)
script/clang-format-overwrite.sh  (+2 -2)
example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp

@@ -50,8 +50,7 @@ static constexpr auto BSpec = ck::tensor_operation::device::TensorSpecialization
 static constexpr auto DESpec = ck::tensor_operation::device::TensorSpecialization::Default;

 using DeviceOpInstanceKKNN =
     ck::tensor_operation::device::DeviceBatchedContractionMultipleD_Wmma_CShuffle<
         NumDimG,
         NumDimM,
         NumDimN,
         NumDimK,
@@ -327,7 +326,8 @@ int main(int argc, char* argv[])
     DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize());
     DeviceMem b_device_buf(sizeof(BDataType) * b_gs_ns_ks.mDesc.GetElementSpaceSize());
     DeviceMem d_device_buf(sizeof(DDataType) * d_gs_ms_ns.mDesc.GetElementSpaceSize());
     DeviceMem e_device_buf(sizeof(EDataType) *
                            e_gs_ms_ns_device_result.mDesc.GetElementSpaceSize());

     a_device_buf.ToDevice(a_gs_ms_ks.mData.data());
     b_device_buf.ToDevice(b_gs_ns_ks.mData.data());
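For readers unfamiliar with the DeviceMem helper used above: it owns a device allocation sized in bytes and stages host data into it. A minimal stand-in sketch in plain HIP (this wrapper is hypothetical; CK's real DeviceMem lives in the library utilities and has more methods):

#include <hip/hip_runtime.h>
#include <cstddef>

// Hypothetical minimal analogue of CK's DeviceMem: allocate a raw device
// buffer of `size` bytes and copy host data into it.
struct DeviceMemSketch
{
    explicit DeviceMemSketch(std::size_t size) : size_(size) { (void)hipMalloc(&ptr_, size_); }
    ~DeviceMemSketch() { (void)hipFree(ptr_); }

    // Host -> device copy, mirroring DeviceMem::ToDevice in the example above.
    void ToDevice(const void* host) { (void)hipMemcpy(ptr_, host, size_, hipMemcpyHostToDevice); }

    void* ptr_ = nullptr;
    std::size_t size_;
};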
@@ -379,7 +379,7 @@ int main(int argc, char* argv[])
     ck::index_t K = ck::accumulate_n<ck::index_t>(
         a_gs_ms_ks_lengths.begin() + NumDimG + NumDimM, NumDimK, 1, std::multiplies<>{});

     std::cout << "GMNK=" << G << ", " << M << ", " << N << ", " << K << std::endl;

     std::size_t flop = std::size_t(2) * G * M * N * K;
     std::size_t num_btype = sizeof(ADataType) * G * M * K + sizeof(BDataType) * G * K * N +
                             sizeof(DDataType) * G * M * N + sizeof(EDataType) * G * M * N;
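The counts above drive the example's perf report: a batched GEMM does 2·G·M·N·K floating-point operations and touches each of A, B, D, E once. A small self-contained check with hypothetical sizes (fp16 tensors, so 2 bytes per element):

#include <cstddef>
#include <cstdio>

int main()
{
    // Hypothetical problem size; the example derives G/M/N/K from the tensors.
    std::size_t G = 4, M = 1024, N = 1024, K = 1024;

    std::size_t flop = std::size_t(2) * G * M * N * K; // one FMA counts as 2 flops
    // fp16 A (M*K), fp16 B (K*N), fp16 D and E (M*N each), all per batch.
    std::size_t num_btype = 2 * G * M * K + 2 * G * K * N + 2 * G * M * N + 2 * G * M * N;

    std::printf("%.3f GFLOP, %.3f GB moved\n", flop / 1.e9, num_btype / 1.e9);
}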
example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt

@@ -5,7 +5,9 @@ add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16
 add_example_executable(example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp)
 add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp)
 add_example_executable(example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp)
-add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp)
+if(GPU_TARGETS MATCHES "gfx1100")
+add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp)
+endif()

 add_custom_target(example_gemm_scale_softmax_gemm)
 add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16)

@@ -16,5 +18,7 @@ add_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_scale_soft
 add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16)
 add_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16)
-add_custom_target(example_gemm_scale_softmax_gemm_wmma)
-add_dependencies(example_gemm_scale_softmax_gemm_wmma example_batched_gemm_scale_softmax_gemm_permute_wmma_fp16)
+if(GPU_TARGETS MATCHES "gfx1100")
+add_custom_target(example_gemm_scale_softmax_gemm_wmma)
+add_dependencies(example_gemm_scale_softmax_gemm_wmma example_batched_gemm_scale_softmax_gemm_permute_wmma_fp16)
+endif()
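The guard ties the WMMA examples to gfx1100 at configure time. A related runtime check (a hedged sketch using the HIP API; not part of this commit) would inspect the device architecture before launching a WMMA kernel:

#include <hip/hip_runtime.h>
#include <cstring>

// Sketch: true when device 0 reports a gfx11-class architecture, which is
// what the gfx1100 CMake guard above targets at build time.
bool device_supports_wmma()
{
    hipDeviceProp_t prop{};
    if(hipGetDeviceProperties(&prop, 0) != hipSuccess)
        return false;
    // gcnArchName looks like "gfx1100:sramecc-:xnack-".
    return std::strstr(prop.gcnArchName, "gfx11") != nullptr;
}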
example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp

@@ -104,7 +104,7 @@ using DeviceGemmInstance =
         16,          // MPerWMMA
         16,          // LPerWMMA
         16,          // NPerWMMA
-        //Per repeat = wave_m = wave_num, wave_n = 1
+        // Per repeat = wave_m = wave_num, wave_n = 1
         1,           // MRepeat
         8,           // LRepeat
         4,           // NRepeat
example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp

@@ -62,7 +62,8 @@ struct BlockwiseGemmWMMA
     static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I4);
     static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I4);

     static constexpr auto wmma_gemm =
         WmmaGemm<FloatA, FloatB, FloatAcc, MPerWMMA, NPerWMMA, KPack, TransposeC>{};

     static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA);
     static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA);
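MWaves/NWaves split the block tile among waves: each wave covers MRepeat × MPerWMMA rows and NRepeat × NPerWMMA columns, and the static_assert below requires MWaves · NWaves · WaveSize threads per block. A standalone sketch of that arithmetic with hypothetical tile numbers:

#include <cstdio>

// Hypothetical tile configuration; the example instance fixes
// MPerWMMA = NPerWMMA = 16, the rest is chosen here for illustration.
constexpr int MPerBlock = 128, NPerBlock = 128;
constexpr int MPerWMMA = 16, NPerWMMA = 16;
constexpr int MRepeat = 2, NRepeat = 2;
constexpr int WaveSize = 32; // wave32 on gfx11

constexpr int MWaves = MPerBlock / (MRepeat * MPerWMMA); // 128 / 32 = 4
constexpr int NWaves = NPerBlock / (NRepeat * NPerWMMA); // 128 / 32 = 4
static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 && NPerBlock % (NPerWMMA * NRepeat) == 0);

int main() { std::printf("threads per block = %d\n", MWaves * NWaves * WaveSize); } // prints 512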
@@ -149,13 +150,8 @@ struct BlockwiseGemmWMMA
         const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk3D();

-        return make_tuple(Number<m0>{},
-                          blk_idx[I0],
-                          waveId_m,
-                          Number<n0>{},
-                          waveId_n,
-                          blk_idx[I1],
-                          blk_idx[I2]);
+        return make_tuple(
+            Number<m0>{}, blk_idx[I0], waveId_m, Number<n0>{}, waveId_n, blk_idx[I1], blk_idx[I2]);
     }

     using Tuple5 = decltype(CalculateAThreadOriginDataIndex());
@@ -169,7 +165,8 @@ struct BlockwiseGemmWMMA
         static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
                       "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");

         static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 &&
                           NPerBlock % (NPerWMMA * NRepeat) == 0,
                       "wrong!");
     }
@@ -180,20 +177,15 @@ struct BlockwiseGemmWMMA
         constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
             wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();

-        // constexpr auto NSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I0];
-        // constexpr auto MThreadPerSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I1];
+        // constexpr auto NSubGroup =
+        // c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I0]; constexpr auto MThreadPerSubGroup
+        // = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I1];
         constexpr auto NAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];

         return make_naive_tensor_descriptor_packed(
             // |MRepeat |MWave |MSubGroup |NRepeat |NWave
             // |NThreadPerSubGroup |MAccVgprs
             make_tuple(Number<MRepeat>{}, I1, I1, Number<NRepeat>{}, I1, I1, NAccVgprs));
     }

     // Thread level, register decriptor. Vector-write
include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp

@@ -605,9 +605,11 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
             b_grid_desc_n_k_ =
                 DeviceOp::MakeBGridDescriptor_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides);
             ds_grid_desc_m_n_ =
                 DeviceOp::MakeDsGridDescriptor_M_N(ds_gs_ms_ns_lengths, ds_gs_ms_ns_strides);
             e_grid_desc_m_n_ =
                 DeviceOp::MakeEGridDescriptor_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides);

             a_grid_desc_k0_m_k1_ = DeviceOp::MakeAGridDescriptor_K0_M_K1(a_grid_desc_m_k_);
             b_grid_desc_k0_n_k1_ = DeviceOp::MakeBGridDescriptor_K0_N_K1(b_grid_desc_n_k_);

@@ -619,8 +621,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
                     ds_grid_desc_m_n_);
             e_grid_desc_mblock_mperblock_nblock_nperblock =
                 GridwiseOp::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                     e_grid_desc_m_n_);

             // for sanity check of vector memory access
             a_mz_stride_ = a_gs_ms_ks_strides[NumDimG + NumDimM - 1];

@@ -696,9 +697,11 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
         {
             const index_t G = arg.e_grid_desc_g_m_n_.GetLength(I0);
             const index_t grid_size =
                 arg.block_2_ctile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * G;
             const auto K =
                 arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);

             auto launch_kernel = [&](auto has_main_k_block_loop) {
                 constexpr bool has_main_loop = has_main_k_block_loop.value;
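The launch math above sizes the grid as (output tiles per batch) × G and derives K from the K0/K1 split of the A descriptor. A simplified host-side model (hedged: Block2CTileMap can use more elaborate tile orderings than this plain product):

#include <cassert>

// Simplified model of the grid sizing for the batched kernel launch above:
// one workgroup per MPerBlock x NPerBlock output tile, times G batches.
int grid_size_model(int G, int M, int N, int MPerBlock, int NPerBlock)
{
    assert(M % MPerBlock == 0 && N % NPerBlock == 0); // enforced by the argument checks
    return (M / MPerBlock) * (N / NPerBlock) * G;
}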
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp

@@ -339,10 +339,10 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
               p_c_grid_{p_c_grid},
               a_grid_desc_ak0_m_ak1_{
                   DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
               b0_grid_desc_bk0_l_bk1_{DeviceOp::MakeB0GridDescriptor_BK0_L_BK1(
                   b0_gs_ls_ks_lengths, b0_gs_ls_ks_strides)},
               b1_grid_desc_bl0_n_bl1_{DeviceOp::MakeB1GridDescriptor_BL0_N_BL1(
                   b1_gs_ns_ls_lengths, b1_gs_ns_ls_strides)},
               c_grid_desc_m_n_{
                   Transform::MakeCGridDescriptor_M_N(c_gs_ms_ns_lengths, c_gs_ms_ns_strides)},
               a_grid_desc_g_m_k_{

@@ -450,9 +450,11 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
         float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
         {
             const index_t grid_size =
                 arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.batch_count_;
             const auto K =
                 arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);

             auto launch_kernel = [&](auto has_main_k_block_loop) {
                 const auto kernel = kernel_batched_gemm_softmax_gemm_wmma_cshuffle<

@@ -592,8 +594,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
             B0BlockTransferSrcVectorDim == 2 ? arg.b0_lz_kz_strides_[1] : arg.b0_lz_kz_strides_[0];
         const auto b1_stride_lowest =
             B1BlockTransferSrcVectorDim == 2 ? arg.b1_nz_lz_strides_[1] : arg.b1_nz_lz_strides_[0];
         const auto c_stride_lowest = arg.c_mz_nz_strides_[1];

         if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 ||
              c_stride_lowest == 1))
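This sanity check picks the stride along each tensor's vectorized dimension and, as written, rejects the argument only if none of the four tensors is contiguous there. A standalone sketch of the per-tensor rule, with hypothetical names:

// Sketch: a vectorized global access of width > 1 is only valid along a
// unit-stride dimension, so check the stride the transfer will vectorize over.
bool vector_access_ok(int src_vector_dim, const int strides[2])
{
    // Hypothetical 2-D tail [slow, fast]; dim 2 means "vectorize the fast axis",
    // mirroring `SrcVectorDim == 2 ? strides_[1] : strides_[0]` above.
    int stride_lowest = (src_vector_dim == 2) ? strides[1] : strides[0];
    return stride_lowest == 1;
}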
@@ -610,8 +611,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
         return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
     }

     static auto MakeArgument(
         const ADataType* p_a,
         const B0DataType* p_b0,
         const B1DataType* p_b1,

@@ -664,8 +664,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
     }

     // polymorphic
     std::unique_ptr<BaseArgument> MakeArgumentPointer(
         const void* p_a,
         const void* p_b0,
         const void* p_b1,
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp

@@ -250,15 +250,15 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
         return transform_tensor_descriptor(
             A1BlockDesc_AL0_M_AL1{},
             make_tuple(make_pass_through_transform(Number<A_L0>{}),
                        make_unmerge_transform(make_tuple(Number<MRepeat>{}, I1, I1)),
                        make_pass_through_transform(Number<A_L1>{})),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
             make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
     }

     template <typename B1BlockDesc_BL0_N_BL1>
     __host__ __device__ static constexpr auto
     MakeB1BlockDescriptor_L0_N0_N1_N2_L1(const B1BlockDesc_BL0_N_BL1&)
     {
         constexpr index_t B_K0 = B1BlockDesc_BL0_N_BL1{}.GetLength(I0);
         constexpr index_t B_K1 = B1BlockDesc_BL0_N_BL1{}.GetLength(I2);
@@ -317,7 +317,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
     __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
     {
         // LDS allocation for A and B: be careful of alignment
         const index_t gemm0_bytes_end =
             (SharedMemTrait::a_block_space_size_aligned * sizeof(FloatA) +
              SharedMemTrait::b0_block_space_size_aligned * sizeof(FloatB0));
         const index_t gemm1_bytes_end =
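GetSharedMemoryNumberOfByte sums the aligned LDS footprints of the staged tiles; the launcher must keep the total under the hardware's per-workgroup LDS limit. A toy budget check (tile sizes hypothetical; 64 KiB is a common per-workgroup LDS cap):

#include <cstddef>

// Toy LDS budget in the spirit of GetSharedMemoryNumberOfByte():
// fp16 A and B0 tiles staged per block, checked against a 64 KiB cap.
constexpr std::size_t a_tile_elems  = 128 * 32; // MPerBlock x KPerBlock (hypothetical)
constexpr std::size_t b0_tile_elems = 128 * 32; // LPerBlock x KPerBlock (hypothetical)
constexpr std::size_t gemm0_bytes_end = a_tile_elems * 2 + b0_tile_elems * 2; // sizeof(fp16) == 2

static_assert(gemm0_bytes_end <= 64 * 1024, "tile would overflow per-workgroup LDS");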
@@ -360,8 +361,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
             return false;
         }

         if(!(M % MPerBlock == 0 && L % LPerBlock == 0 && K % KPerBlock == 0 &&
              N % NPerBlock == 0))
         {
             return false;
         }
@@ -466,10 +466,13 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
         static constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
             GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();

         static constexpr auto c_block_space_size =
             c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat
                 .GetElementSpaceSize();
     };

     template <bool HasMainKBlockLoop,
               typename C0MatrixMask,
               typename Block2CTileMap = DefaultBlock2CTileMap>
     __device__ static void Run(const FloatA* __restrict__ p_a_grid,
                                const FloatB0* __restrict__ p_b0_grid,
                                const FloatB1* __restrict__ p_b1_grid,
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp

@@ -141,7 +141,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
         GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;

     template <typename ABlockDesc_AK0_M_AK1>
     __host__ __device__ static constexpr auto
     MakeABlockDescriptor_K0_M0_M1_M2_K1(const ABlockDesc_AK0_M_AK1&)
     {
         constexpr index_t A_K0 = ABlockDesc_AK0_M_AK1{}.GetLength(I0);
         constexpr index_t A_K1 = ABlockDesc_AK0_M_AK1{}.GetLength(I2);

@@ -158,7 +159,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
     }

     template <typename BBlockDesc_BK0_N_BK1>
     __host__ __device__ static constexpr auto
     MakeBBlockDescriptor_K0_N0_N1_N2_K1(const BBlockDesc_BK0_N_BK1&)
     {
         constexpr index_t B_K0 = BBlockDesc_BK0_N_BK1{}.GetLength(I0);
         constexpr index_t B_K1 = BBlockDesc_BK0_N_BK1{}.GetLength(I2);
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp

@@ -1398,24 +1398,37 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
             element_op_(v_this_row, src_buf[Number<src_offset>{}]);

             // apply intra-row swizzle permute
-            if constexpr(IntraRowSwizzlePerm){
-                // origin: 0xfedcba98, 0x76543210
+            if constexpr(IntraRowSwizzlePerm)
+            {
+                // origin:
+                // 0xfedcba98,
+                // 0x76543210
                 temp = __builtin_amdgcn_permlane16(
                     temp, type_convert<int>(v_this_row), 0xeca86420, 0xfdb97531, 1, 0);
                 v_this_row = type_convert<float>(temp);
             }

             // apply inter-row permute.
             temp = __builtin_amdgcn_permlanex16(
                 temp, type_convert<int>(v_this_row), LowEightRowlaneIdx, HighEightRowLaneIdx, 1, 0);
             v_theother_row = type_convert<float>(temp);

-            if(get_thread_local_1d_id() % 32 < 16){
+            if(get_thread_local_1d_id() % 32 < 16)
+            {
                 // apply type convert
                 dst_buf(Number<dst_offset>{}) = type_convert<DstData>(v_this_row);
                 dst_buf(Number<dst_offset + DstScalarPerVector>{}) =
                     type_convert<DstData>(v_theother_row);
             }
-            else{
+            else
+            {
                 // apply type convert
                 dst_buf(Number<dst_offset + DstScalarPerVector>{}) =
                     type_convert<DstData>(v_this_row);
                 dst_buf(Number<dst_offset>{}) = type_convert<DstData>(v_theother_row);
             }
         });
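The two builtins above shuffle register data across lanes: permlane16 with selectors 0xeca86420/0xfdb97531 de-interleaves even and odd lanes within a 16-lane row, and permlanex16 exchanges data with the other 16-lane half of the wave. A host-side simulation of the selector decoding (hedged: this follows my reading of the RDNA3 ISA's nibble-per-lane select encoding, not code from this repo):

#include <cstdint>
#include <cstdio>

// Each destination lane i of a 16-lane row reads the source lane given by the
// i-th nibble of the 64-bit select (hi:lo): lanes 0-7 from `lo`, 8-15 from `hi`.
void simulate_permlane16(const int src[16], int dst[16], uint32_t lo, uint32_t hi)
{
    for(int i = 0; i < 8; ++i)
        dst[i] = src[(lo >> (4 * i)) & 0xF];
    for(int i = 0; i < 8; ++i)
        dst[8 + i] = src[(hi >> (4 * i)) & 0xF];
}

int main()
{
    int src[16], dst[16];
    for(int i = 0; i < 16; ++i) src[i] = i;
    simulate_permlane16(src, dst, 0xeca86420u, 0xfdb97531u);
    // Prints 0 2 4 6 8 10 12 14 1 3 5 7 9 11 13 15: even lanes packed into the
    // low half, odd lanes into the high half, i.e. the intra-row swizzle above.
    for(int i = 0; i < 16; ++i) std::printf("%d ", dst[i]);
    std::printf("\n");
}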
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
include/ck/utility/amd_wmma.hpp

@@ -25,9 +25,8 @@ struct intrin_wmma_f32_16x16x16_f16_w32<16, 16>
         // delete them.
         // amd_assembly_wmma_f32_16x16x16_f16_w32(
         // reg_a, reg_b, reg_c.template AsType<float8_t>()(Number<0>{}));
         reg_c.template AsType<float8_t>()(Number<0>{}) =
             __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(
                 reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
     }
 };
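The intrinsic wrapper boils down to one compiler builtin: on gfx11 in wave32 mode, __builtin_amdgcn_wmma_f32_16x16x16_f16_w32 multiplies two 16x16 fp16 fragments (a half16 per lane) into an fp32 accumulator (a float8 per lane). A bare-bones device-side sketch (assumes clang targeting gfx11; the vector typedefs here are illustrative, CK's float8_t and fp16 vector types live in its own headers):

// Compile with e.g.: hipcc --offload-arch=gfx1100 ...
#include <hip/hip_runtime.h>

typedef _Float16 half16 __attribute__((ext_vector_type(16)));
typedef float    float8 __attribute__((ext_vector_type(8)));

// Each lane of the wave32 holds 16 fp16 elements of A and of B, and 8 fp32
// accumulator elements of the 16x16 C tile.
__device__ float8 wmma_16x16x16(half16 a, half16 b, float8 acc)
{
    return __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a, b, acc);
}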
script/clang-format-overwrite.sh

-find . -name deps -prune -o -name build -prune -o -iname '*.h' -o -iname '*.hpp' -o -iname '*.cpp' -o -iname '*.h.in' -o -iname '*.hpp.in' -o -iname '*.cpp.in' -o -iname '*.cl' -o -iname '*.cuh' -o -iname '*.cu' -o -iname '*.inc' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-10 -i -style=file {}'
-# git status --porcelain | awk '$1 != "D" && (match($2, "\\.cpp|hpp|inc")) {print $2}' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-10 -i -style=file {}'
+# find . -name deps -prune -o -name build -prune -o -iname '*.h' -o -iname '*.hpp' -o -iname '*.cpp' -o -iname '*.h.in' -o -iname '*.hpp.in' -o -iname '*.cpp.in' -o -iname '*.cl' -o -iname '*.cuh' -o -iname '*.cu' -o -iname '*.inc' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-10 -i -style=file {}'
+git status --porcelain | awk '$1 != "D" && (match($2, "\\.cpp|hpp|inc")) {print $2}' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-10 -i -style=file {}'