gaoqiong / composable_kernel
Commit 8dbb73b1, authored Feb 16, 2023 by aska-0096
Commit message: format
Parent: cc6a534f
Showing 14 changed files with 186 additions and 171 deletions.
example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp (+49, -49)
example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt (+7, -3)
example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp (+11, -11)
example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc (+1, -1)
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp (+13, -21)
include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp (+15, -12)
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp (+22, -23)
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp (+26, -23)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp (+2, -2)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp (+9, -7)
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp (+25, -12)
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp (+1, -1)
include/ck/utility/amd_wmma.hpp (+3, -4)
script/clang-format-overwrite.sh (+2, -2)
example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp (View file @ 8dbb73b1)
...
...
@@ -45,56 +45,55 @@ using CDEElementOp = ck::tensor_operation::element_wise::Add;

static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;

static constexpr auto ASpec  = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto BSpec  = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto DESpec = ck::tensor_operation::device::TensorSpecialization::Default;

using DeviceOpInstanceKKNN =
    ck::tensor_operation::device::DeviceBatchedContractionMultipleD_Wmma_CShuffle<
        NumDimG, NumDimM, NumDimN, NumDimK,
        ADataType, BDataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
        AElementOp, BElementOp, CDEElementOp,
        GemmSpec, ASpec, BSpec, DESpec,
        256, 128, 128, 4, 8, 16, 16, 4, 2,
        S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
        S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
        1, 1, S<1, 32, 1, 8>, 8>;

using DeviceOpInstance = DeviceOpInstanceKKNN;
...
...
@@ -327,7 +326,8 @@ int main(int argc, char* argv[])

    DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize());
    DeviceMem b_device_buf(sizeof(BDataType) * b_gs_ns_ks.mDesc.GetElementSpaceSize());
    DeviceMem d_device_buf(sizeof(DDataType) * d_gs_ms_ns.mDesc.GetElementSpaceSize());
    DeviceMem e_device_buf(sizeof(EDataType) * e_gs_ms_ns_device_result.mDesc.GetElementSpaceSize());

    a_device_buf.ToDevice(a_gs_ms_ks.mData.data());
    b_device_buf.ToDevice(b_gs_ns_ks.mData.data());
...
...
@@ -379,7 +379,7 @@ int main(int argc, char* argv[])

    ck::index_t K = ck::accumulate_n<ck::index_t>(
        a_gs_ms_ks_lengths.begin() + NumDimG + NumDimM, NumDimK, 1, std::multiplies<>{});

    std::cout << "GMNK=" << G << ", " << M << ", " << N << ", " << K << std::endl;

    std::size_t flop = std::size_t(2) * G * M * N * K;
    std::size_t num_btype = sizeof(ADataType) * G * M * K + sizeof(BDataType) * G * K * N +
                            sizeof(DDataType) * G * M * N + sizeof(EDataType) * G * M * N;
...
...
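The flop and num_btype expressions above follow directly from the contraction shape: 2 * G * M * N * K multiply-adds, plus one pass over A, B and D and one write of E. A small host-side re-computation with hypothetical sizes (the G/M/N/K values and the 2-byte fp16 element size are illustrative assumptions, not the example's defaults):

    #include <cstddef>
    #include <iostream>

    int main()
    {
        // Hypothetical problem size; the example's actual defaults may differ.
        const std::size_t G = 16, M = 1024, N = 1024, K = 64;
        const std::size_t elem_bytes = 2; // assume fp16 for ADataType/BDataType/DDataType/EDataType

        // Same formulas as in the hunk above.
        const std::size_t flop      = std::size_t(2) * G * M * N * K;
        const std::size_t num_btype = elem_bytes * G * M * K + elem_bytes * G * K * N +
                                      elem_bytes * G * M * N + elem_bytes * G * M * N;

        std::cout << "flop = " << flop << " (2147483648), bytes = " << num_btype << " (71303168)"
                  << std::endl;
        return 0;
    }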
example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt (View file @ 8dbb73b1)
...
...
@@ -5,7 +5,9 @@ add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16
 add_example_executable(example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp)
 add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp)
 add_example_executable(example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp)
-add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp)
+if(GPU_TARGETS MATCHES "gfx1100")
+    add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp)
+endif()
 add_custom_target(example_gemm_scale_softmax_gemm)
 add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16)
...
...
@@ -16,5 +18,7 @@ add_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_scale_soft
 add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16)
 add_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16)
-add_custom_target(example_gemm_scale_softmax_gemm_wmma)
-add_dependencies(example_gemm_scale_softmax_gemm_wmma example_batched_gemm_scale_softmax_gemm_permute_wmma_fp16)
+if(GPU_TARGETS MATCHES "gfx1100")
+    add_custom_target(example_gemm_scale_softmax_gemm_wmma)
+    add_dependencies(example_gemm_scale_softmax_gemm_wmma example_batched_gemm_scale_softmax_gemm_permute_wmma_fp16)
+endif()
example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp (View file @ 8dbb73b1)
...
...
@@ -94,17 +94,17 @@ using DeviceGemmInstance =
        TensorSpecB1,
        TensorSpecC,
        256,
        128, // MPerBlock
        128, // LPerBlock
        4,   // K0PerBlock
        8,   // K1
        64,  // NPerBlock
        4,   // L0PerBlock
        8,   // L1
        16,  // MPerWMMA
        16,  // LPerWMMA
        16,  // NPerWMMA
        // Per repeat = wave_m = wave_num, wave_n = 1
        1,   // MRepeat
        8,   // LRepeat
        4,   // NRepeat
...
...
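The in-source comment "Per repeat = wave_m = wave_num, wave_n = 1" can be cross-checked against the tile sizes above using the MWaves/NWaves formula from blockwise_gemm_wmma.hpp. The sketch below is not part of the example; it assumes a wave size of 32 (gfx11 wave32) and reads the commented arguments above as BlockSize = 256, MPerBlock = 128, and so on.

    // Compile-time cross-check of the wave layout implied by the tile parameters above.
    constexpr int BlockSize = 256;
    constexpr int MPerBlock = 128, MPerWMMA = 16, MRepeat = 1;
    constexpr int LPerBlock = 128, LPerWMMA = 16, LRepeat = 8;
    constexpr int WaveSize  = 32; // assumption: wave32 on gfx11

    // Same formula as MWaves/NWaves in blockwise_gemm_wmma.hpp, applied to the M and L tiles.
    constexpr int MWaves = MPerBlock / (MRepeat * MPerWMMA); // 128 / (1 * 16) = 8
    constexpr int LWaves = LPerBlock / (LRepeat * LPerWMMA); // 128 / (8 * 16) = 1

    // "wave_m = wave_num, wave_n = 1": all 8 waves of the 256-thread block lie along M.
    static_assert(MWaves == BlockSize / WaveSize, "wave_m equals the number of waves per block");
    static_assert(LWaves == 1, "wave_n is 1");
    static_assert(MWaves * LWaves * WaveSize == BlockSize, "waves exactly fill the thread block");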
example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc (View file @ 8dbb73b1)
...
...
@@ -218,7 +218,7 @@ int run(int argc, char* argv[])
    Tensor<ADataType> a_g_m_k({BatchCount, M, K});
    Tensor<B0DataType> b0_g_k_n({BatchCount, K, N});
    Tensor<B1DataType> b1_g_n_o({BatchCount, N, O});
    Tensor<Acc0DataType> acc0_g_m_n({BatchCount, M, N});       // scratch object after gemm0
    Tensor<ADataType> a1_g_m_n({BatchCount, M, N});            // scratch object after softmax
    Tensor<CDataType> c_g_m_o_host_result({BatchCount, M, O}); // scratch object after gemm1
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp (View file @ 8dbb73b1)
...
...
@@ -33,10 +33,10 @@ template <index_t BlockSize,
 * B: K0PerBlock x NPerBlock x K1
 * Destination
 * C, non-transpose
 * thread level: MRepeat x NRepeat x MAccVgprs
 * block level: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs
 * KPACK == WMMA_K = 16
 *
 * Option: Read from VMEM, small buffer hold each thread own required data (Skip LDS)
 * Source:
 * A(if skip LDS): MRepeat x KPack
...
...
@@ -62,7 +62,8 @@ struct BlockwiseGemmWMMA
    static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I4);
    static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I4);

    static constexpr auto wmma_gemm =
        WmmaGemm<FloatA, FloatB, FloatAcc, MPerWMMA, NPerWMMA, KPack, TransposeC>{};

    static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA);
    static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA);
...
...
@@ -149,13 +150,8 @@ struct BlockwiseGemmWMMA
        const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk3D();

        return make_tuple(
            Number<m0>{}, blk_idx[I0], waveId_m, Number<n0>{}, waveId_n, blk_idx[I1], blk_idx[I2]);
    }

    using Tuple5 = decltype(CalculateAThreadOriginDataIndex());
...
...
@@ -169,7 +165,8 @@ struct BlockwiseGemmWMMA
        static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
                      "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");

        static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 && NPerBlock % (NPerWMMA * NRepeat) == 0,
                      "wrong!");
    }
...
...
@@ -180,20 +177,15 @@ struct BlockwiseGemmWMMA
        constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
            wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();

        // constexpr auto NSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I0];
        // constexpr auto MThreadPerSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I1];

        constexpr auto NAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];

        return make_naive_tensor_descriptor_packed(
            // |MRepeat |MWave |MSubGroup |NRepeat |NWave
            // |NThreadPerSubGroup |MAccVgprs
            make_tuple(Number<MRepeat>{}, I1, I1, Number<NRepeat>{}, I1, I1, NAccVgprs));
    }

    // Thread level, register decriptor. Vector-write
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp (View file @ 8dbb73b1)
...
...
@@ -393,10 +393,10 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
    }

    // Gridwise descriptor, mapping to whole given provblem.
    using AGridDesc_M_K  = decltype(MakeAGridDescriptor_M_K({}, {}));
    using BGridDesc_N_K  = decltype(MakeBGridDescriptor_N_K({}, {}));
    using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}))>;
    using EGridDesc_M_N  = decltype(MakeEGridDescriptor_M_N({}, {}));

    using DsGridDesc_G_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_G_M_N({}, {}))>;
    using EGridDesc_G_M_N  = decltype(MakeEGridDescriptor_G_M_N({}, {}));
...
...
@@ -604,10 +604,12 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
                DeviceOp::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
            b_grid_desc_n_k_ =
                DeviceOp::MakeBGridDescriptor_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides);
            ds_grid_desc_m_n_ =
                DeviceOp::MakeDsGridDescriptor_M_N(ds_gs_ms_ns_lengths, ds_gs_ms_ns_strides);
            e_grid_desc_m_n_ =
                DeviceOp::MakeEGridDescriptor_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides);

            a_grid_desc_k0_m_k1_ = DeviceOp::MakeAGridDescriptor_K0_M_K1(a_grid_desc_m_k_);
            b_grid_desc_k0_n_k1_ = DeviceOp::MakeBGridDescriptor_K0_N_K1(b_grid_desc_n_k_);
...
...
@@ -619,8 +621,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
                    ds_grid_desc_m_n_);
            e_grid_desc_mblock_mperblock_nblock_nperblock =
                GridwiseOp::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n_);

            // for sanity check of vector memory access
            a_mz_stride_ = a_gs_ms_ks_strides[NumDimG + NumDimM - 1];
...
...
@@ -696,9 +697,11 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
        {
            const index_t G = arg.e_grid_desc_g_m_n_.GetLength(I0);

            const index_t grid_size =
                arg.block_2_ctile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * G;

            const auto K =
                arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);

            auto launch_kernel = [&](auto has_main_k_block_loop) {
                constexpr bool has_main_loop = has_main_k_block_loop.value;
...
...
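The launch path above sets grid_size to the per-batch tile count times the number of batches G. Assuming the default block-to-C-tile map assigns one workgroup per MPerBlock x NPerBlock output tile, the grid size can be estimated on the host; the problem size below is hypothetical, and the 128 x 128 tile is read from the example-29 instantiation under the usual CK argument ordering.

    #include <cstdint>
    #include <iostream>

    // Rough model of grid_size = block_2_ctile_map_.CalculateGridSize(e_grid_desc_m_n_) * G,
    // assuming one workgroup per output tile (ceil division covers the MNKPadding case).
    int main()
    {
        const std::int64_t G = 16, M = 1000, N = 1000;       // hypothetical batched contraction size
        const std::int64_t MPerBlock = 128, NPerBlock = 128;

        const std::int64_t tiles_m   = (M + MPerBlock - 1) / MPerBlock; // 8
        const std::int64_t tiles_n   = (N + NPerBlock - 1) / NPerBlock; // 8
        const std::int64_t grid_size = tiles_m * tiles_n * G;           // 8 * 8 * 16 = 1024

        std::cout << "estimated grid_size = " << grid_size << " workgroups" << std::endl;
        return 0;
    }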
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp (View file @ 8dbb73b1)
...
...
@@ -54,10 +54,10 @@ template <index_t NumDimG,
          ck::index_t MPerBlock,
          ck::index_t LPerBlock,
          ck::index_t K0PerBlock, // K0 * K1 = Gemm0 GEMM_K Dim
          ck::index_t K1,         //
          ck::index_t NPerBlock,
          ck::index_t L0PerBlock,
          ck::index_t L1,
          ck::index_t MPerWMMA,
          ck::index_t LPerWMMA,
          ck::index_t NPerWMMA,
...
@@ -136,7 +136,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
using
Transform
=
TransformBatchedContractionContractionToBatchedGemmGemm
<
Sequence
<
NumDimG
,
NumDimM
,
NumDimL
,
NumDimK
,
NumDimN
>
,
Sequence
<
MPerBlock
,
LPerBlock
,
KPerBlock
,
NPerBlock
>
,
...
...
@@ -261,7 +261,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
        K1, //
        NPerBlock,
        L0PerBlock,
        L1,
        MPerWMMA,
        LPerWMMA,
        NPerWMMA,
...
...
@@ -339,10 +339,10 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
              p_c_grid_{p_c_grid},
              a_grid_desc_ak0_m_ak1_{
                  DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
              b0_grid_desc_bk0_l_bk1_{
                  DeviceOp::MakeB0GridDescriptor_BK0_L_BK1(b0_gs_ls_ks_lengths, b0_gs_ls_ks_strides)},
              b1_grid_desc_bl0_n_bl1_{
                  DeviceOp::MakeB1GridDescriptor_BL0_N_BL1(b1_gs_ns_ls_lengths, b1_gs_ns_ls_strides)},
              c_grid_desc_m_n_{
                  Transform::MakeCGridDescriptor_M_N(c_gs_ms_ns_lengths, c_gs_ms_ns_strides)},
              a_grid_desc_g_m_k_{
...
...
@@ -408,7 +408,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
        B0GridDesc_BK0_L_BK1 b0_grid_desc_bk0_l_bk1_;
        B1GridDesc_BL0_N_BL1 b1_grid_desc_bl0_n_bl1_;
        CGridDesc_M_N c_grid_desc_m_n_;
        AGridDesc_G_M_K a_grid_desc_g_m_k_;
        B0GridDesc_G_L_K b0_grid_desc_g_l_k_;
        B1GridDesc_G_N_L b1_grid_desc_g_n_l_;
...
...
@@ -450,9 +450,11 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            const index_t grid_size =
                arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.batch_count_;

            const auto K =
                arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);

            auto launch_kernel = [&](auto has_main_k_block_loop) {
                const auto kernel = kernel_batched_gemm_softmax_gemm_wmma_cshuffle<
...
...
@@ -552,11 +554,11 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
        }

        // Check if C permute dimension matches GEMM + GEMM shape
        const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded
        const index_t c_m = arg.c_grid_desc_m_n_.GetLength(I0);
        const index_t c_n = arg.c_grid_desc_m_n_.GetLength(I1);
        const index_t a_m = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
        const index_t b1_n = arg.b1_grid_desc_bl0_n_bl1_.GetLength(I1);

        if(!(c_g == arg.batch_count_ && c_m == a_m && c_n == b1_n))
        {
...
...
@@ -592,8 +594,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
            B0BlockTransferSrcVectorDim == 2 ? arg.b0_lz_kz_strides_[1] : arg.b0_lz_kz_strides_[0];
        const auto b1_stride_lowest =
            B1BlockTransferSrcVectorDim == 2 ? arg.b1_nz_lz_strides_[1] : arg.b1_nz_lz_strides_[0];
        const auto c_stride_lowest = arg.c_mz_nz_strides_[1];

        if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 ||
             c_stride_lowest == 1))
...
...
@@ -610,8 +611,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
    }

    static auto MakeArgument(const ADataType* p_a,
                             const B0DataType* p_b0,
                             const B1DataType* p_b1,
...
...
@@ -634,7 +634,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
                             B0ElementwiseOperation b0_element_op,
                             AccElementwiseOperation acc_element_op,
                             B1ElementwiseOperation b1_element_op,
                             CElementwiseOperation c_element_op)
    {
        return Argument{p_a,
                        p_b0,
...
...
@@ -664,8 +664,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
    }

    // polymorphic
    std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
                                                      const void* p_b0,
                                                      const void* p_b1,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp (View file @ 8dbb73b1)
...
...
@@ -135,10 +135,10 @@ template <typename FloatA,
          index_t MPerBlock,
          index_t LPerBlock,
          index_t K0PerBlock, // K0 * K1Value = Gemm0 GEMM_K Dim
          index_t K1Value,
          index_t NPerBlock,
          index_t L0PerBlock,
          index_t L1Value,
          index_t MPerWmma,
          index_t LPerWmma,
          index_t NPerWmma,
...
...
@@ -209,8 +209,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
    __host__ __device__ static constexpr auto
    MakeA0BlockDescriptor_K0_M0_M1_M2_K1(const A0BlockDesc_AK0_M_AK1&)
    {
        constexpr index_t A_K0 = A0BlockDesc_AK0_M_AK1{}.GetLength(I0);
        constexpr index_t A_K1 = A0BlockDesc_AK0_M_AK1{}.GetLength(I2);

        constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma);

        return transform_tensor_descriptor(
...
...
@@ -227,8 +227,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
    __host__ __device__ static constexpr auto
    MakeB0BlockDescriptor_K0_L0_L1_L2_K1(const B0BlockDesc_BK0_L_BK1&)
    {
        constexpr index_t B_K0 = B0BlockDesc_BK0_L_BK1{}.GetLength(I0);
        constexpr index_t B_K1 = B0BlockDesc_BK0_L_BK1{}.GetLength(I2);

        constexpr index_t LWaves = LPerBlock / (LRepeat * LPerWmma);

        return transform_tensor_descriptor(
            B0BlockDesc_BK0_L_BK1{},
...
...
@@ -250,18 +250,18 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
        return transform_tensor_descriptor(
            A1BlockDesc_AL0_M_AL1{},
            make_tuple(make_pass_through_transform(Number<A_L0>{}),
                       make_unmerge_transform(make_tuple(Number<MRepeat>{}, I1, I1)),
                       make_pass_through_transform(Number<A_L1>{})),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
    }

    template <typename B1BlockDesc_BL0_N_BL1>
    __host__ __device__ static constexpr auto
    MakeB1BlockDescriptor_L0_N0_N1_N2_L1(const B1BlockDesc_BL0_N_BL1&)
    {
        constexpr index_t B_K0 = B1BlockDesc_BL0_N_BL1{}.GetLength(I0);
        constexpr index_t B_K1 = B1BlockDesc_BL0_N_BL1{}.GetLength(I2);

        constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma);

        return transform_tensor_descriptor(
            B1BlockDesc_BL0_N_BL1{},
...
...
@@ -317,17 +317,18 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        // LDS allocation for A and B: be careful of alignment
        const index_t gemm0_bytes_end =
            (SharedMemTrait::a_block_space_size_aligned * sizeof(FloatA) +
             SharedMemTrait::b0_block_space_size_aligned * sizeof(FloatB0));

        const index_t gemm1_bytes_end =
            (SharedMemTrait::b1_block_space_offset + SharedMemTrait::b1_block_space_size_aligned) *
            sizeof(FloatB1);

        const index_t softmax_bytes_end =
            (SharedMemTrait::reduction_space_offset + SharedMemTrait::reduction_space_size_aligned) *
            sizeof(FloatAcc0);

        const index_t c_block_bytes_end = SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
...
...
@@ -360,8 +361,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
            return false;
        }

        if(!(M % MPerBlock == 0 && L % LPerBlock == 0 && K % KPerBlock == 0 && N % NPerBlock == 0))
        {
            return false;
        }
...
...
@@ -432,7 +432,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
        MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>;
    using DefaultBlock2CTileMap =
        remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;

    struct SharedMemTrait
    {
        // LDS allocation for A and B: be careful of alignment
...
...
@@ -453,7 +453,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
            b1_block_desc_bl0_n_bl1.GetElementSpaceSize(), max_lds_align);

        static constexpr auto a_block_space_offset  = 0;
        static constexpr auto b0_block_space_offset = a_block_space_size_aligned.value;
        static constexpr auto b1_block_space_offset = 0;

        // LDS allocation for reduction
...
...
@@ -466,10 +466,13 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
        static constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
            GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();
        static constexpr auto c_block_space_size =
            c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetElementSpaceSize();
    };

    template <bool HasMainKBlockLoop,
              typename C0MatrixMask,
              typename Block2CTileMap = DefaultBlock2CTileMap>
    __device__ static void Run(const FloatA* __restrict__ p_a_grid,
                               const FloatB0* __restrict__ p_b0_grid,
                               const FloatB1* __restrict__ p_b1_grid,
...
...
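GetSharedMemoryNumberOfByte() above computes four LDS watermarks: gemm0 (A + B0 tiles), gemm1 (B1 tile), the softmax reduction workspace, and the C-shuffle buffer. How they are combined is not visible in these hunks; the sketch below assumes the usual CK pattern of reusing one LDS allocation across stages and returning the maximum of the four, which is an assumption rather than code quoted from this file.

    #include <algorithm>
    #include <cstddef>

    // Hedged sketch: combine the four byte watermarks by taking their maximum,
    // since the gemm0, gemm1, softmax and c-shuffle stages would reuse the same LDS.
    constexpr std::size_t shared_mem_bytes_sketch(std::size_t gemm0_bytes_end,
                                                  std::size_t gemm1_bytes_end,
                                                  std::size_t softmax_bytes_end,
                                                  std::size_t c_block_bytes_end)
    {
        return std::max({gemm0_bytes_end, gemm1_bytes_end, softmax_bytes_end, c_block_bytes_end});
    }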
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp (View file @ 8dbb73b1)
...
...
@@ -165,7 +165,7 @@ __global__ void
    static constexpr index_t NumDTensor = DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size();

    DsPointer p_ds_grid_grp;

    static_for<0, NumDTensor, 1>{}(
...
...
@@ -530,7 +530,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
    template <typename DsGridDesc_M_N_>
    __host__ __device__ static constexpr auto
    MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc_M_N_& ds_grid_desc_m_n)
    {
        return generate_tuple(
            [&](auto i) {
                return MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[i]);
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp (View file @ 8dbb73b1)
...
...
@@ -141,10 +141,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
        GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;

    template <typename ABlockDesc_AK0_M_AK1>
    __host__ __device__ static constexpr auto
    MakeABlockDescriptor_K0_M0_M1_M2_K1(const ABlockDesc_AK0_M_AK1&)
    {
        constexpr index_t A_K0 = ABlockDesc_AK0_M_AK1{}.GetLength(I0);
        constexpr index_t A_K1 = ABlockDesc_AK0_M_AK1{}.GetLength(I2);

        constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma);

        return transform_tensor_descriptor(
...
...
@@ -157,11 +158,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
            make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
    }

    template <typename BBlockDesc_BK0_N_BK1>
    __host__ __device__ static constexpr auto
    MakeBBlockDescriptor_K0_N0_N1_N2_K1(const BBlockDesc_BK0_N_BK1&)
    {
        constexpr index_t B_K0 = BBlockDesc_BK0_N_BK1{}.GetLength(I0);
        constexpr index_t B_K1 = BBlockDesc_BK0_N_BK1{}.GetLength(I2);

        constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma);

        return transform_tensor_descriptor(
            BBlockDesc_BK0_N_BK1{},
...
...
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp (View file @ 8dbb73b1)
...
...
@@ -1311,11 +1311,11 @@ template <typename SrcData,
          typename ElementwiseOperation,
          typename SliceLengths,
          typename DimAccessOrder,
          index_t DstVectorDim,
          index_t DstScalarPerVector,
          uint32_t LowEightRowlaneIdx,
          uint32_t HighEightRowLaneIdx,
          bool IntraRowSwizzlePerm,
          typename enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                             bool>::type = false>
struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
...
...
@@ -1383,7 +1383,7 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
        // copy data from src_buf into dst_vector
        static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
            // idx_md err. as dst access 2 strided elements while src visit 1 per loop
            constexpr index_t src_offset =
                src_desc.CalculateOffset(src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
...
...
@@ -1398,24 +1398,37 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
            element_op_(v_this_row, src_buf[Number<src_offset>{}]);

            // apply intra-row swizzle permute
            if constexpr(IntraRowSwizzlePerm)
            {
                // origin: 0xfedcba98, 0x76543210
                temp = __builtin_amdgcn_permlane16(
                    temp, type_convert<int>(v_this_row), 0xeca86420, 0xfdb97531, 1, 0);
                v_this_row = type_convert<float>(temp);
            }

            // apply inter-row permute.
            temp = __builtin_amdgcn_permlanex16(
                temp, type_convert<int>(v_this_row), LowEightRowlaneIdx, HighEightRowLaneIdx, 1, 0);
            v_theother_row = type_convert<float>(temp);

            if(get_thread_local_1d_id() % 32 < 16)
            {
                // apply type convert
                dst_buf(Number<dst_offset>{}) = type_convert<DstData>(v_this_row);
                dst_buf(Number<dst_offset + DstScalarPerVector>{}) =
                    type_convert<DstData>(v_theother_row);
            }
            else
            {
                // apply type convert
                dst_buf(Number<dst_offset + DstScalarPerVector>{}) =
                    type_convert<DstData>(v_this_row);
                dst_buf(Number<dst_offset>{}) = type_convert<DstData>(v_theother_row);
            }
        });
...
...
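The inter/intra-row swizzle above leans on the lane-select encoding of __builtin_amdgcn_permlane16 and __builtin_amdgcn_permlanex16: each nibble of the two 32-bit selectors names the source lane for one destination lane of a 16-lane row. The host-side model below decodes the constants 0xeca86420 / 0xfdb97531 under that encoding; the nibble order is my reading of the intrinsic, not code taken from this header.

    #include <cstdint>
    #include <cstdio>

    // Decode the permlane16 selectors used above: destination lane i of a 16-lane row
    // takes its value from the source lane given by nibble i of (sel_hi : sel_lo).
    int main()
    {
        const std::uint32_t sel_lo = 0xeca86420; // destination lanes 0..7
        const std::uint32_t sel_hi = 0xfdb97531; // destination lanes 8..15

        for(int dst = 0; dst < 16; ++dst)
        {
            const std::uint32_t sel = dst < 8 ? sel_lo : sel_hi;
            const int src           = static_cast<int>((sel >> (4 * (dst % 8))) & 0xF);
            std::printf("dst lane %2d <- src lane %2d\n", dst, src);
        }
        // Result: lanes 0..7 gather the even source lanes 0,2,...,14 and lanes 8..15
        // gather the odd source lanes 1,3,...,15, i.e. an even/odd de-interleave of the
        // "origin: 0xfedcba98, 0x76543210" lane order mentioned in the comment above.
        return 0;
    }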
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp (View file @ 8dbb73b1)
...
...
@@ -444,7 +444,7 @@ struct WmmaGemm
                make_pass_through_transform(MWave),
                make_pass_through_transform(Number<wmma_instr.num_thread_per_subgroups>{}),
                make_pass_through_transform(NBlockxRepeat),
                make_pass_through_transform(NWave),
                make_unmerge_transform(make_tuple(Number<wmma_instr.num_subgroups>{},
                                                  Number<wmma_instr.num_acc_vgprs_per_wave>{}))),
            make_tuple(Sequence<0>{},
...
...
include/ck/utility/amd_wmma.hpp (View file @ 8dbb73b1)
...
...
@@ -24,10 +24,9 @@ struct intrin_wmma_f32_16x16x16_f16_w32<16, 16>
        // * Inline assembly need to elimate the duplicated data load, compiler won't help you
        // delete them.
        // amd_assembly_wmma_f32_16x16x16_f16_w32(
        //     reg_a, reg_b, reg_c.template AsType<float8_t>()(Number<0>{}));
        reg_c.template AsType<float8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(
            reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
    }
};
...
...
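For context on the builtin used above: __builtin_amdgcn_wmma_f32_16x16x16_f16_w32 performs one 16x16x16 matrix multiply-accumulate per wave, with fp16 A/B fragments and fp32 accumulation spread across the 32 lanes (8 accumulator values per lane, matching the float8_t view of reg_c). The scalar reference below describes only the arithmetic, not the per-lane fragment layout, and _Float16 stands in for the device fp16 type.

    #include <array>

    using Mat16x16F16 = std::array<std::array<_Float16, 16>, 16>;
    using Mat16x16F32 = std::array<std::array<float, 16>, 16>;

    // Scalar reference for C += A * B, the arithmetic carried out by one WMMA instruction.
    void wmma_16x16x16_reference(const Mat16x16F16& a, const Mat16x16F16& b, Mat16x16F32& c)
    {
        for(int m = 0; m < 16; ++m)
        {
            for(int n = 0; n < 16; ++n)
            {
                float acc = c[m][n];
                for(int k = 0; k < 16; ++k)
                {
                    acc += static_cast<float>(a[m][k]) * static_cast<float>(b[k][n]);
                }
                c[m][n] = acc;
            }
        }
    }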
script/clang-format-overwrite.sh (View file @ 8dbb73b1)
-# find . -name deps -prune -o -name build -prune -o -iname '*.h' -o -iname '*.hpp' -o -iname '*.cpp' -o -iname '*.h.in' -o -iname '*.hpp.in' -o -iname '*.cpp.in' -o -iname '*.cl' -o -iname '*.cuh' -o -iname '*.cu' -o -iname '*.inc' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-10 -i -style=file {}'
-git status --porcelain | awk '$1 != "D" && (match($2, "\\.cpp|hpp|inc")) {print $2}' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-10 -i -style=file {}'
+find . -name deps -prune -o -name build -prune -o -iname '*.h' -o -iname '*.hpp' -o -iname '*.cpp' -o -iname '*.h.in' -o -iname '*.hpp.in' -o -iname '*.cpp.in' -o -iname '*.cl' -o -iname '*.cuh' -o -iname '*.cu' -o -iname '*.inc' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-10 -i -style=file {}'
+# git status --porcelain | awk '$1 != "D" && (match($2, "\\.cpp|hpp|inc")) {print $2}' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-10 -i -style=file {}'