gaoqiong / composable_kernel_ROCM · Commits · f0759faf

Commit f0759faf authored Apr 26, 2024 by Jun Liu
Merge branch 'develop' into amd-develop
parents 20ddaeba 764164b4
Changes 103

Showing 20 changed files with 2765 additions and 662 deletions (+2765 -662)
example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fastgelu_bf16_i8.cpp                                      +273  -0
example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_multiply_bias_fastgelu_bf16_i8.cpp                        +274  -0
include/ck/host_utility/flush_cache.hpp                                                                +229  -0
include/ck/stream_config.hpp                                                                           +3    -0
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp                            +6    -2
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp                            +40   -17
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp                            +2    -2
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp                      +16   -10
include/ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp                               +128  -0
include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp             +9    -5
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp                  +520  -461
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp                            +196  -135
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp  +787  -0
include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp                              +75   -9
include/ck/tensor_operation/gpu/element/element_wise_operation.hpp                                     +49   -0
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp                               +10   -0
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp                                            +9    -4
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp                       +2    -12
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp                         +122  -4
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp                               +15   -1

example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fastgelu_bf16_i8.cpp
example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fastgelu_bf16_i8.cpp
(new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/check_err.hpp"

#include "ck/utility/blkgemmpipe_scheduler.hpp"

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using F16  = ck::half_t;
using BF16 = ck::bhalf_t;
using I8   = int8_t;
using F32  = float;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

using A0DataType       = BF16;
using AsDataType       = ck::Tuple<A0DataType>;
using B0DataType       = I8;
using B1DataType       = BF16;
using BsDataType       = ck::Tuple<B0DataType, B1DataType>;
using AccDataType      = F32;
using CShuffleDataType = F32;
using D0DataType       = BF16;
using DsDataType       = ck::Tuple<>;
using EDataType        = BF16;

using A0Layout = Row;
using AsLayout = ck::Tuple<A0Layout>;
using B0Layout = Row;
using B1Layout = B0Layout;
using BsLayout = ck::Tuple<B0Layout, B1Layout>;
using D0Layout = Row;
using DsLayout = ck::Tuple<>;
using ELayout  = Row;

using Multiply    = ck::tensor_operation::element_wise::Multiply;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using FastGelu    = ck::tensor_operation::element_wise::FastGelu;

using AElementOp   = PassThrough;
using BElementOp   = Multiply;
using CDEElementOp = FastGelu;

static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;

using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Xdl_CShuffle
    // clang-format off
///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
///######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
///######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
    <AsLayout, BsLayout, DsLayout, ELayout,
     AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType,
     AElementOp, BElementOp, CDEElementOp, GemmSpec,
     1, 256,
     128, 128, 64,
     8, 4,
     32, 32,
     2, 2,
     S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
     S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0,
     1, 1, S<1, 32, 1, 8>, 8,
     ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v4>;
// clang-format on

int main(int argc, char* argv[])
{
    bool do_verification = true;
    int init_method      = 1;
    bool time_kernel     = false;

    // GEMM shape
    ck::index_t M = 4096;
    ck::index_t N = 768;
    ck::index_t K = 6144;

    ck::index_t StrideA = K;
    ck::index_t StrideB = N;
    ck::index_t StrideD = 0;
    ck::index_t StrideE = N;

    if(argc == 1)
    {
        // use default case
    }
    else if(argc == 4)
    {
        do_verification = std::stoi(argv[1]);
        init_method     = std::stoi(argv[2]);
        time_kernel     = std::stoi(argv[3]);
    }
    else if(argc == 11)
    {
        do_verification = std::stoi(argv[1]);
        init_method     = std::stoi(argv[2]);
        time_kernel     = std::stoi(argv[3]);

        M = std::stoi(argv[4]);
        N = std::stoi(argv[5]);
        K = std::stoi(argv[6]);

        StrideA = std::stoi(argv[7]);
        StrideB = std::stoi(argv[8]);
        StrideD = std::stoi(argv[9]);
        StrideE = std::stoi(argv[10]);
    }
    else
    {
        printf("arg1: verification (0=no, 1=yes)\n");
        printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
        printf("arg3: time kernel (0=no, 1=yes)\n");
        printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE\n");
        exit(0);
    }

    auto f_host_tensor_descriptor =
        [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
            using namespace ck::literals;

            if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
            {
                return HostTensorDescriptor({row, col}, {stride, 1_uz});
            }
            else
            {
                return HostTensorDescriptor({row, col}, {1_uz, stride});
            }
        };

    Tensor<A0DataType> a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{}));
    Tensor<B0DataType> b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
    Tensor<B1DataType> b1_k_n(f_host_tensor_descriptor(K, N, 0, B1Layout{}));
    Tensor<D0DataType> d_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{}));
    Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
    Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));

    std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl;
    std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl;
    std::cout << "b1_k_n: " << b1_k_n.mDesc << std::endl;
    std::cout << "d_m_n: " << d_m_n.mDesc << std::endl;
    std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;

    switch(init_method)
    {
    case 0: break;
    case 1:
        a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-5, 5});
        b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5});
        b1_k_n.GenerateTensorValue(GeneratorTensor_2<B1DataType>{0, 5});
        d_m_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-5, 5});
        break;
    default:
        a0_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
        b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5});
        b1_k_n.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 5});
        d_m_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{-0.5, 0.5});
    }

    DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize());
    DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize());
    DeviceMem b1_device_buf(sizeof(B1DataType) * b1_k_n.mDesc.GetElementSpaceSize());
    DeviceMem d_device_buf(sizeof(D0DataType) * d_m_n.mDesc.GetElementSpaceSize());
    DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());

    a0_device_buf.ToDevice(a0_m_k.mData.data());
    b0_device_buf.ToDevice(b0_k_n.mData.data());
    b1_device_buf.ToDevice(b1_k_n.mData.data());
    d_device_buf.ToDevice(d_m_n.mData.data());
    e_device_buf.ToDevice(e_m_n_device_result.mData.data());

    auto a_element_op   = AElementOp{};
    auto b_element_op   = BElementOp{};
    auto cde_element_op = CDEElementOp{};

    constexpr ck::index_t NumATensor = 1;
    constexpr ck::index_t NumBTensor = 2;
    constexpr ck::index_t NumDTensor = 0;

    // do GEMM
    auto device_op = DeviceOpInstance{};
    auto invoker   = device_op.MakeInvoker();
    auto argument  = device_op.MakeArgument(
        std::array<const void*, NumATensor>{a0_device_buf.GetDeviceBuffer()},
        std::array<const void*, NumBTensor>{b0_device_buf.GetDeviceBuffer(),
                                            b1_device_buf.GetDeviceBuffer()},
        std::array<const void*, NumDTensor>{},
        e_device_buf.GetDeviceBuffer(),
        M,
        N,
        K,
        std::array<ck::index_t, NumATensor>{StrideA},
        std::array<ck::index_t, NumBTensor>{StrideB, 0},
        std::array<ck::index_t, NumDTensor>{},
        StrideE,
        a_element_op,
        b_element_op,
        cde_element_op);

    if(!device_op.IsSupportedArgument(argument))
    {
        throw std::runtime_error(
            "wrong! device_gemm with the specified compilation parameters does "
            "not support this GEMM problem");
    }

    float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});

    std::size_t flop = std::size_t(2) * M * N * K;
    std::size_t num_btype =
        sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N;

    float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

    float gb_per_sec = num_btype / 1.E6 / ave_time;

    std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
              << std::endl;

    e_device_buf.FromDevice(e_m_n_device_result.mData.data());

    if(do_verification)
    {
        Tensor<CShuffleDataType> c_m_n({M, N});

        Tensor<A0DataType> a_m_k({M, K});
        Tensor<B1DataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));

        for(int n = 0; n < N; ++n)
        {
            for(int k = 0; k < K; ++k)
            {
                b_element_op(b_k_n(k, n), b0_k_n(k, n), b1_k_n(k, n));
            }
        }

        using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<A0DataType,
                                                                                B1DataType,
                                                                                CShuffleDataType,
                                                                                AccDataType,
                                                                                PassThrough,
                                                                                PassThrough,
                                                                                PassThrough>;

        auto ref_gemm    = ReferenceGemmInstance{};
        auto ref_invoker = ref_gemm.MakeInvoker();

        auto ref_argument = ref_gemm.MakeArgument(
            a0_m_k, b_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{});

        ref_invoker.Run(ref_argument);

        for(int m = 0; m < M; ++m)
        {
            for(int n = 0; n < N; ++n)
            {
                cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n));
            }
        }

        e_device_buf.FromDevice(e_m_n_device_result.mData.data());

        return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1;
    }

    return 0;
}
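For orientation, the math this example maps onto the GPU can be written as a short host-side sketch. The buffer names, plain float storage, and the tanh-based gelu approximation below are illustrative assumptions only; the example itself verifies against CK's GeneratorTensor utilities, ReferenceGemm, and element_wise::FastGelu.

// Scalar sketch of the example's computation:
//   B[k][n] = b0[k][n] * b1[k][n]              (BElementOp = Multiply, int8 weight * bf16 scale)
//   C[m][n] = sum_k A[m][k] * B[k][n]
//   E[m][n] = FastGelu(C[m][n])                (CDEElementOp = FastGelu)
#include <cmath>
#include <vector>

void reference_gemm_multiply_fastgelu(const std::vector<float>& a,  // M x K input
                                      const std::vector<float>& b0, // K x N weight (dequantized)
                                      const std::vector<float>& b1, // K x N scale
                                      std::vector<float>& e,        // M x N output
                                      int M, int N, int K)
{
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
        {
            float c = 0.f;
            for(int k = 0; k < K; ++k)
                c += a[m * K + k] * (b0[k * N + n] * b1[k * N + n]);

            // tanh-based gelu approximation (assumption; CK's FastGelu has its own definition)
            const float u = 0.7978845608f * (c + 0.044715f * c * c * c);
            e[m * N + n]  = 0.5f * c * (1.f + std::tanh(u));
        }
}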
example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_multiply_bias_fastgelu_bf16_i8.cpp
(new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/check_err.hpp"

#include "ck/utility/blkgemmpipe_scheduler.hpp"

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using F16  = ck::half_t;
using BF16 = ck::bhalf_t;
using I8   = int8_t;
using F32  = float;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

using A0DataType       = BF16;
using AsDataType       = ck::Tuple<A0DataType>;
using B0DataType       = I8;
using B1DataType       = BF16;
using BsDataType       = ck::Tuple<B0DataType>;
using AccDataType      = F32;
using CShuffleDataType = F32;
using D0DataType       = BF16;
using DsDataType       = ck::Tuple<B1DataType, D0DataType>;
using EDataType        = BF16;

using A0Layout = Row;
using AsLayout = ck::Tuple<A0Layout>;
using B0Layout = Row;
using B1Layout = B0Layout;
using BsLayout = ck::Tuple<B0Layout>;
using D0Layout = Row;
using DsLayout = ck::Tuple<B1Layout, D0Layout>;
using ELayout  = Row;

using PassThrough         = ck::tensor_operation::element_wise::PassThrough;
using MultiplyAddFastGelu = ck::tensor_operation::element_wise::MultiplyAddFastGelu;

using AElementOp   = PassThrough;
using BElementOp   = PassThrough;
using CDEElementOp = MultiplyAddFastGelu;

static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;

using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Xdl_CShuffle
    // clang-format off
///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
///######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
///######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
    <AsLayout, BsLayout, DsLayout, ELayout,
     AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType,
     AElementOp, BElementOp, CDEElementOp, GemmSpec,
     1, 256,
     128, 128, 64,
     8, 4,
     32, 32,
     2, 2,
     S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
     S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0,
     1, 1, S<1, 32, 1, 8>, 8,
     ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v4>;
// clang-format on

int main(int argc, char* argv[])
{
    bool do_verification = true;
    int init_method      = 1;
    bool time_kernel     = false;

    // GEMM shape
    ck::index_t M = 4096;
    ck::index_t N = 768;
    ck::index_t K = 6144;

    ck::index_t StrideA = K;
    ck::index_t StrideB = N;
    ck::index_t StrideD = 0;
    ck::index_t StrideE = N;

    if(argc == 1)
    {
        // use default case
    }
    else if(argc == 4)
    {
        do_verification = std::stoi(argv[1]);
        init_method     = std::stoi(argv[2]);
        time_kernel     = std::stoi(argv[3]);
    }
    else if(argc == 11)
    {
        do_verification = std::stoi(argv[1]);
        init_method     = std::stoi(argv[2]);
        time_kernel     = std::stoi(argv[3]);

        M = std::stoi(argv[4]);
        N = std::stoi(argv[5]);
        K = std::stoi(argv[6]);

        StrideA = std::stoi(argv[7]);
        StrideB = std::stoi(argv[8]);
        StrideD = std::stoi(argv[9]);
        StrideE = std::stoi(argv[10]);
    }
    else
    {
        printf("arg1: verification (0=no, 1=yes)\n");
        printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
        printf("arg3: time kernel (0=no, 1=yes)\n");
        printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE\n");
        exit(0);
    }

    auto f_host_tensor_descriptor =
        [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
            using namespace ck::literals;

            if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
            {
                return HostTensorDescriptor({row, col}, {stride, 1_uz});
            }
            else
            {
                return HostTensorDescriptor({row, col}, {1_uz, stride});
            }
        };

    Tensor<A0DataType> a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{}));
    Tensor<B0DataType> b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
    Tensor<B1DataType> b1_k_n(f_host_tensor_descriptor(K, N, 0, B1Layout{}));
    Tensor<D0DataType> d_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{}));
    Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
    Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));

    std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl;
    std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl;
    std::cout << "b1_k_n: " << b1_k_n.mDesc << std::endl;
    std::cout << "d_m_n: " << d_m_n.mDesc << std::endl;
    std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;

    switch(init_method)
    {
    case 0: break;
    case 1:
        a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-5, 5});
        b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5});
        b1_k_n.GenerateTensorValue(GeneratorTensor_2<B1DataType>{0, 5});
        d_m_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-5, 5});
        break;
    default:
        a0_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
        b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5});
        b1_k_n.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 5});
        d_m_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{-0.5, 0.5});
    }

    DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize());
    DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize());
    DeviceMem b1_device_buf(sizeof(B1DataType) * b1_k_n.mDesc.GetElementSpaceSize());
    DeviceMem d_device_buf(sizeof(D0DataType) * d_m_n.mDesc.GetElementSpaceSize());
    DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());

    a0_device_buf.ToDevice(a0_m_k.mData.data());
    b0_device_buf.ToDevice(b0_k_n.mData.data());
    b1_device_buf.ToDevice(b1_k_n.mData.data());
    d_device_buf.ToDevice(d_m_n.mData.data());
    e_device_buf.ToDevice(e_m_n_device_result.mData.data());

    auto a_element_op   = AElementOp{};
    auto b_element_op   = BElementOp{};
    auto cde_element_op = CDEElementOp{};

    constexpr ck::index_t NumATensor = 1;
    constexpr ck::index_t NumBTensor = 1;
    constexpr ck::index_t NumDTensor = 2;

    // do GEMM
    auto device_op = DeviceOpInstance{};
    auto invoker   = device_op.MakeInvoker();
    auto argument  = device_op.MakeArgument(
        std::array<const void*, NumATensor>{a0_device_buf.GetDeviceBuffer()},
        std::array<const void*, NumBTensor>{b0_device_buf.GetDeviceBuffer()},
        std::array<const void*, NumDTensor>{b1_device_buf.GetDeviceBuffer(),
                                            d_device_buf.GetDeviceBuffer()},
        e_device_buf.GetDeviceBuffer(),
        M,
        N,
        K,
        std::array<ck::index_t, NumATensor>{StrideA},
        std::array<ck::index_t, NumBTensor>{StrideB},
        std::array<ck::index_t, NumDTensor>{0, StrideD},
        StrideE,
        a_element_op,
        b_element_op,
        cde_element_op);

    if(!device_op.IsSupportedArgument(argument))
    {
        throw std::runtime_error(
            "wrong! device_gemm with the specified compilation parameters does "
            "not support this GEMM problem");
    }

    float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});

    std::size_t flop = std::size_t(2) * M * N * K;
    std::size_t num_btype =
        sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N;

    float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

    float gb_per_sec = num_btype / 1.E6 / ave_time;

    std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
              << std::endl;

    e_device_buf.FromDevice(e_m_n_device_result.mData.data());

    if(do_verification)
    {
        Tensor<CShuffleDataType> c_m_n({M, N});

        Tensor<A0DataType> a_m_k({M, K});
        Tensor<B1DataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));

#if 0
        for(int n = 0; n < N; ++n)
        {
            for(int k = 0; k < K; ++k)
            {
                b_element_op(b_k_n(k, n), b0_k_n(k, n), b1_k_n(k, n));
            }
        }
#endif

        using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<A0DataType,
                                                                                B0DataType,
                                                                                CShuffleDataType,
                                                                                AccDataType,
                                                                                PassThrough,
                                                                                PassThrough,
                                                                                PassThrough>;

        auto ref_gemm    = ReferenceGemmInstance{};
        auto ref_invoker = ref_gemm.MakeInvoker();

        auto ref_argument = ref_gemm.MakeArgument(
            a0_m_k, b0_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{});

        ref_invoker.Run(ref_argument);

        for(int m = 0; m < M; ++m)
        {
            for(int n = 0; n < N; ++n)
            {
                cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), b1_k_n(0, n), d_m_n(m, n));
            }
        }

        e_device_buf.FromDevice(e_m_n_device_result.mData.data());

        return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1;
    }

    return 0;
}
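This second example differs from the first only in its epilogue: the per-column bf16 scale and the bias tensor enter through the Ds inputs and are fused by MultiplyAddFastGelu, as the verification loop above shows. A scalar sketch of that fused epilogue follows; float types and the tanh-based gelu approximation are assumptions for illustration, not CK's exact definition.

// E[m][n] = FastGelu( C[m][n] * b1[n] + d[m][n] ), where C is the bf16 x int8 GEMM accumulator,
// b1 is a per-column scale and d is a bias.
#include <cmath>

inline float multiply_add_fastgelu(float c, float b1, float d)
{
    const float x = c * b1 + d;
    const float u = 0.7978845608f * (x + 0.044715f * x * x * x);
    return 0.5f * x * (1.f + std::tanh(u));
}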
include/ck/host_utility/flush_cache.hpp
(new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <hip/hip_runtime.h>
#include <set>

#include "ck/ck.hpp"
#include "ck/stream_config.hpp"
#include "ck/host_utility/hip_check_error.hpp"
#include "ck/utility/flush_icache.hpp"

namespace ck {
namespace utility {

template <typename Argument>
struct RotatingMemWrapper
{
    using ADataType = decltype(Argument::p_a_grid);
    using BDataType = decltype(Argument::p_b_grid);

    RotatingMemWrapper() = delete;
    RotatingMemWrapper(Argument& arg_,
                       std::size_t rotating_count_,
                       std::size_t size_a_,
                       std::size_t size_b_)
        : arg(arg_), rotating_count(rotating_count_), size_a(size_a_), size_b(size_b_)
    {
        p_a_grids.push_back(arg.p_a_grid);
        p_b_grids.push_back(arg.p_b_grid);
        for(size_t i = 1; i < rotating_count; i++)
        {
            {
                void* pADeviceBuf;
                hip_check_error(hipMalloc(static_cast<void**>(&pADeviceBuf), size_a_));
                hip_check_error(hipMemcpy(static_cast<void*>(pADeviceBuf),
                                          const_cast<void*>(p_a_grids[0]),
                                          size_a_,
                                          hipMemcpyDeviceToDevice));
                p_a_grids.push_back(pADeviceBuf);
            }

            {
                void* pBDeviceBuf;
                hip_check_error(hipMalloc(static_cast<void**>(&pBDeviceBuf), size_b_));
                hip_check_error(hipMemcpy(static_cast<void*>(pBDeviceBuf),
                                          const_cast<void*>(p_b_grids[0]),
                                          size_b_,
                                          hipMemcpyDeviceToDevice));
                p_b_grids.push_back(pBDeviceBuf);
            }
        }
    }

    void Next()
    {
        if(rotating_count > 1)
        {
            std::size_t idx = iter++ % rotating_count;
            arg.p_a_grid    = reinterpret_cast<ADataType>(p_a_grids[idx]);
            arg.p_b_grid    = reinterpret_cast<BDataType>(p_b_grids[idx]);
        }
    }

    void Print()
    {
        std::cout << "RotatingMemWrapper: { size_a: " << size_a << ", size_b: " << size_b
                  << ", rotating_count: " << rotating_count << "}" << std::endl;
    }

    ~RotatingMemWrapper()
    {
        if(rotating_count > 1)
        {
            // restore ptr
            arg.p_a_grid = reinterpret_cast<ADataType>(p_a_grids[0]);
            arg.p_b_grid = reinterpret_cast<BDataType>(p_b_grids[0]);

            // free device mem
            for(size_t i = 1; i < rotating_count; i++)
            {
                hip_check_error(hipFree(const_cast<void*>(p_a_grids[i])));
                hip_check_error(hipFree(const_cast<void*>(p_b_grids[i])));
            }
        }
    }

    private:
    Argument& arg;
    std::size_t iter           = 0;
    std::size_t rotating_count = 1;
    std::size_t size_a         = 0;
    std::size_t size_b         = 0;
    std::vector<const void*> p_a_grids;
    std::vector<const void*> p_b_grids;
};

inline void flush_icache()
{
    hipDeviceProp_t deviceProps;
    hip_check_error(hipGetDeviceProperties(&deviceProps, 0));
    int32_t gpu_block3 = deviceProps.multiProcessorCount * 60;

    ck::flush_icache<<<dim3(gpu_block3), dim3(64), 0, nullptr>>>();
    hip_check_error(hipGetLastError());
}

// if TimePrePress == false, return time does not include preprocess's time
template <bool TimePreprocess, typename Args, typename F, typename PreProcessFunc>
float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
                                             PreProcessFunc preprocess,
                                             F kernel,
                                             dim3 grid_dim,
                                             dim3 block_dim,
                                             std::size_t lds_byte,
                                             Args& args)
{
#if CK_TIME_KERNEL
#define MEDIAN 1
    if(stream_config.time_kernel_)
    {
#if DEBUG_LOG
        printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d}\n",
               __func__,
               grid_dim.x,
               grid_dim.y,
               grid_dim.z,
               block_dim.x,
               block_dim.y,
               block_dim.z);

        printf("Warm up %d times\n", stream_config.cold_niters_);
#endif
        // warm up
        for(int i = 0; i < stream_config.cold_niters_; ++i)
        {
            kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args);
            hip_check_error(hipGetLastError());
        }

        const int nrepeat = stream_config.nrepeat_;
        if(nrepeat == 0)
        {
            return 0.0;
        }
#if DEBUG_LOG
        printf("Start running %d times...\n", nrepeat);
#endif

#if MEDIAN
        std::set<float> times;
#else
        float total_time = 0;
#endif
        for(int i = 0; i < nrepeat; ++i)
        {
            if constexpr(!TimePreprocess)
            {
                preprocess();
            }

            hipEvent_t start, stop;

            hip_check_error(hipEventCreate(&start));
            hip_check_error(hipEventCreate(&stop));

            hip_check_error(hipDeviceSynchronize());
            hip_check_error(hipEventRecord(start, stream_config.stream_id_));

            // calculate preprocess time
            if constexpr(TimePreprocess)
            {
                preprocess();
            }

            // run real kernel
            kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args);
            hip_check_error(hipGetLastError());
            // end real kernel

            hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
            hip_check_error(hipEventSynchronize(stop));
            float cur_time = 0;
            hip_check_error(hipEventElapsedTime(&cur_time, start, stop));
#if MEDIAN
            times.insert(cur_time);
#else
            total_time += cur_time;
#endif

#if DEBUG_LOG
            std::cout << "i: " << i << " cur_time: " << cur_time << std::endl;
            printf("args.p_a_grid: %p, args.p_b_grid:%p\n",
                   static_cast<const void*>(args.p_a_grid),
                   static_cast<const void*>(args.p_b_grid));
#endif
        }

#if MEDIAN
        auto mid = times.begin();
        std::advance(mid, (nrepeat - 1) / 2);
        if(nrepeat % 2 == 1)
        {
            return *mid;
        }
        else
        {
            auto mid_next = mid;
            std::advance(mid_next, 1);
            return (*mid + *mid_next) / 2;
        }
#else
        return total_time / nrepeat;
#endif
    }
    else
    {
        preprocess();
        kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args);
        hip_check_error(hipGetLastError());

        return 0;
    }
#else
    kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args);
    hip_check_error(hipGetLastError());

    return 0;
#endif
}

} // namespace utility
} // namespace ck
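A rough sketch of how these utilities compose during benchmarking is shown below. Only RotatingMemWrapper, flush_icache, launch_and_time_kernel_with_preprocess, and StreamConfig::rotating_count come from this commit; the kernel functor, the Argument type (assumed to expose p_a_grid/p_b_grid), and the grid/block/LDS sizes are placeholders.

// Hypothetical usage sketch: rotate A/B between `rotating_count` device copies so repeated
// runs do not hit in cache, flush the instruction cache before each timed iteration, and
// time only the kernel itself (TimePreprocess = false keeps the preprocess cost out).
#include <cstddef>
#include <hip/hip_runtime.h>
#include "ck/stream_config.hpp"
#include "ck/host_utility/flush_cache.hpp"

template <typename Kernel, typename Argument>
float time_gemm_with_rotation(Kernel my_kernel, // placeholder kernel function
                              Argument& arg,    // placeholder argument with p_a_grid / p_b_grid
                              dim3 grid,
                              dim3 block,
                              std::size_t lds_bytes,
                              std::size_t size_a_bytes,
                              std::size_t size_b_bytes,
                              const StreamConfig& stream_config)
{
    ck::utility::RotatingMemWrapper<Argument> rotating_mem(
        arg, stream_config.rotating_count, size_a_bytes, size_b_bytes);

    auto preprocess = [&]() {
        rotating_mem.Next();         // point arg at the next A/B copy
        ck::utility::flush_icache(); // evict the instruction cache
    };

    return ck::utility::launch_and_time_kernel_with_preprocess<false>(
        stream_config, preprocess, my_kernel, grid, block, lds_bytes, arg);
}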
include/ck/stream_config.hpp
@@ -13,4 +13,7 @@ struct StreamConfig
     int log_level_   = 0;
     int cold_niters_ = 5;
     int nrepeat_     = 50;
+
+    bool flush_cache   = false;
+    int rotating_count = 1;
 };
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp
@@ -140,8 +140,10 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
     using Base::AMmaKStride;
     using Base::BMmaKStride;

+    static constexpr index_t WgpPerCU =
+        (4 * warpSize / BlockSize) >= 1 ? 4 * warpSize / BlockSize : 1;
     static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil(
-        32768 / (4 * warpSize / BlockSize),
+        32768 / WgpPerCU,
         (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);

     static constexpr index_t PrefetchStages =
         FullMemBandPrefetchStages >= 2
@@ -631,8 +633,10 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
     static constexpr index_t KPerInnerLoop = math::max(KPerThread / NumMacClusters, KPack);
     static constexpr index_t KRepeat       = KPerThread / KPerInnerLoop;

+    static constexpr index_t WgpPerCU =
+        (4 * warpSize / BlockSize) >= 1 ? 4 * warpSize / BlockSize : 1;
     static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil(
-        32768 / (4 * warpSize / BlockSize),
+        32768 / WgpPerCU,
         (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);

     static constexpr index_t PrefetchStages =
         FullMemBandPrefetchStages >= 2
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp
@@ -184,19 +184,22 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
         constexpr auto ds_read_b_issue_cycle =
             HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4;
         constexpr auto ds_read_a_mfma_rate =
-            (mfma_cycle - 8 + ds_read_a_issue_cycle - 1) / ds_read_a_issue_cycle;
+            (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
         constexpr auto ds_read_b_mfma_rate =
-            (mfma_cycle - 8 + ds_read_b_issue_cycle - 1) / ds_read_b_issue_cycle;
+            (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);
+
+        constexpr auto num_dsread_a_mfma =
+            (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
+        constexpr auto num_dsread_b_mfma =
+            (num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;

         // stage 1
         // Separate this part?
-        constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) >
-                                                      sizeof(ComputeDataType) / sizeof(BDataType)
-                                                  ? sizeof(ComputeDataType) / sizeof(ADataType)
-                                                  : sizeof(ComputeDataType) / sizeof(BDataType);
-        constexpr auto num_mfma_stage1 = num_mfma_inst - num_mfma_per_ds_read *
-                                             (num_ds_read_inst_a / ds_read_a_mfma_rate +
-                                              num_ds_read_inst_b / ds_read_b_mfma_rate);
+        // constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) >
+        //                                               sizeof(ComputeDataType) / sizeof(BDataType)
+        //                                           ? sizeof(ComputeDataType) / sizeof(ADataType)
+        //                                           : sizeof(ComputeDataType) / sizeof(BDataType);
+        constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma);
         constexpr auto num_mfma_per_issue =
             num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b);
         constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a;
@@ -226,16 +229,36 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
             });

             // stage 2
-            static_for<0, num_ds_read_inst_a / ds_read_a_mfma_rate, 1>{}([&](auto i) {
-                ignore = i;
+            static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) {
+                if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >=
+                             ds_read_a_mfma_rate)
+                {
                     __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
-                __builtin_amdgcn_sched_group_barrier(0x008, num_mfma_per_ds_read, 0);    // MFMA
+                }
+                else
+                {
+                    __builtin_amdgcn_sched_group_barrier(
+                        0x100,
+                        num_ds_read_inst_a - (num_dsread_a_mfma - 1) * ds_read_a_mfma_rate,
+                        0); // DS read
+                }
+                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
             });
-            static_for<0, num_ds_read_inst_b / ds_read_b_mfma_rate, 1>{}([&](auto i) {
-                ignore = i;
+            static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) {
+                if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >=
+                             ds_read_b_mfma_rate)
+                {
                     __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
-                __builtin_amdgcn_sched_group_barrier(0x008, num_mfma_per_ds_read, 0);    // MFMA
+                }
+                else
+                {
+                    __builtin_amdgcn_sched_group_barrier(
+                        0x100,
+                        num_ds_read_inst_b - (num_dsread_b_mfma - 1) * ds_read_b_mfma_rate,
+                        0); // DS read
+                }
+                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
             });
         }
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp
@@ -194,9 +194,9 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
         constexpr auto ds_read_b_issue_cycle =
             HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4;
         constexpr auto ds_read_a_mfma_rate =
-            (mfma_cycle - 8 + ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
+            (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
         constexpr auto ds_read_b_mfma_rate =
-            (mfma_cycle - 8 + ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);
+            (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);

         constexpr auto num_dsread_stage1_a = num_ds_read_inst_a / KRepeat * (KRepeat - 1);
         constexpr auto num_dsread_stage1_b = num_ds_read_inst_b / KRepeat * (KRepeat - 1);
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp
@@ -41,7 +41,8 @@ template <typename ThreadGroup,
           index_t SrcScalarPerVector,
           index_t DstScalarPerVector,
           typename ThreadTransferSrcResetCoordinateAfterRunFlags,
-          typename ThreadTransferDstResetCoordinateAfterRunFlags>
+          typename ThreadTransferDstResetCoordinateAfterRunFlags,
+          index_t NumThreadScratch = 1>
 struct ThreadGroupTensorSliceTransfer_v7r2
 {
     static constexpr index_t nDim =
@@ -100,7 +101,7 @@ struct ThreadGroupTensorSliceTransfer_v7r2
            ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
         {
             const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
-                make_multi_index(get_thread_local_1d_id()));
+                make_multi_index(ThreadGroup::GetThreadId()));

             const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
@@ -117,29 +118,33 @@ struct ThreadGroupTensorSliceTransfer_v7r2
         }
     }

-    template <typename SrcBuffers>
-    __device__ void RunRead(const SrcDescs& src_descs, const SrcBuffers& src_bufs)
+    template <typename SrcBuffers, index_t ThreadScratchId = 0>
+    __device__ void RunRead(const SrcDescs& src_descs,
+                            const SrcBuffers& src_bufs,
+                            Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
     {
         if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
            ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
         {
-            threadwise_transfer_.RunRead(src_descs, src_bufs);
+            threadwise_transfer_.RunRead(src_descs, src_bufs, thread_scratch_id);
         }
     }

     template <typename T>
     using is_tuple = decltype(std::declval<T&>().IsTuple());

-    template <typename DstBuffers>
-    __device__ void RunWrite(const DstDescs& dst_descs, DstBuffers dst_bufs)
+    template <typename DstBuffers, index_t ThreadScratchId = 0>
+    __device__ void RunWrite(const DstDescs& dst_descs,
+                             DstBuffers dst_bufs,
+                             Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
     {
         if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
            ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
         {
             if constexpr(is_detected<is_tuple, decltype(dst_bufs)>::value)
-                threadwise_transfer_.RunWrite(dst_descs, dst_bufs);
+                threadwise_transfer_.RunWrite(dst_descs, dst_bufs, thread_scratch_id);
             else
-                threadwise_transfer_.RunWrite(dst_descs, tie(dst_bufs));
+                threadwise_transfer_.RunWrite(dst_descs, tie(dst_bufs), thread_scratch_id);
         }
     }
@@ -206,7 +211,8 @@ struct ThreadGroupTensorSliceTransfer_v7r2
                                           SrcScalarPerVector,
                                           DstScalarPerVector,
                                           ThreadTransferSrcResetCoordinateAfterRunFlags,
-                                          ThreadTransferDstResetCoordinateAfterRunFlags>;
+                                          ThreadTransferDstResetCoordinateAfterRunFlags,
+                                          NumThreadScratch>;

     ThreadwiseTransfer threadwise_transfer_;
 };
include/ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp
(new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <array>
#include <iostream>
#include <vector>
#include <sstream>

#include "device_grouped_gemm.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

///
/// @brief      Structure representing single GEMM problem arguments.
///
///             The pointer to the vector of those structures is passed to the GroupedGEMM entry
///             point kernel.
///
/// @tparam     NumDTensor  The number of D input tensors.
///
template <index_t NumDTensor = 0>
struct GroupedGemmTileLoopKernelArguments
{
    __host__ __device__
    GroupedGemmTileLoopKernelArguments(const void* p_a_grid_,
                                       const void* p_b_grid_,
                                       std::array<const void*, NumDTensor> p_ds_grid_,
                                       void* p_e_grid_,
                                       index_t M_,
                                       index_t N_,
                                       index_t K_,
                                       index_t StrideA_,
                                       index_t StrideB_,
                                       std::array<index_t, NumDTensor> StrideDs_,
                                       index_t StrideE_)
        : p_a_grid{p_a_grid_},
          p_b_grid{p_b_grid_},
          p_ds_grid{p_ds_grid_},
          p_e_grid{p_e_grid_},
          M{M_},
          N{N_},
          K{K_},
          StrideA{StrideA_},
          StrideB{StrideB_},
          StrideDs{StrideDs_},
          StrideE{StrideE_}
    {
    }

    const void* p_a_grid;
    const void* p_b_grid;
    std::array<const void*, NumDTensor> p_ds_grid;
    void* p_e_grid;
    index_t M;
    index_t N;
    index_t K;
    index_t StrideA;
    index_t StrideB;
    std::array<index_t, NumDTensor> StrideDs;
    index_t StrideE;

    void Print() const
    {
        std::stringstream str;
        for(auto sd : StrideDs)
            str << sd << ",";

        std::cout << "arg {"
                  << "M:" << M << ", "
                  << "N:" << N << ", "
                  << "K:" << K << ", "
                  << "SA:" << StrideA << ", "
                  << "SB:" << StrideB << ", "
                  << "SE:" << StrideE << ", "
                  << "SDs: {" << str.str() << "}"
                  << "}" << std::endl;
    }
};

template <typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          typename ADataType,
          typename BDataType,
          typename DsDataType,
          typename EDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CDEElementwiseOperation>
struct DeviceGroupedGemmTileLoop : public DeviceGroupedGemm<ALayout,
                                                            BLayout,
                                                            DsLayout,
                                                            ELayout,
                                                            ADataType,
                                                            BDataType,
                                                            DsDataType,
                                                            EDataType,
                                                            AElementwiseOperation,
                                                            BElementwiseOperation,
                                                            CDEElementwiseOperation>
{
    //----------------------------------------------------------------------------------------------
    /// @brief      Sets the device kernel arguments pointer.
    ///
    /// @param      p_arg              The pointer to the Argument we're going to update.
    /// @param[in]  p_dev_kernel_args  The pointer to the device memory which contains kernel
    ///                                arguments.
    ///
    virtual void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const = 0;

    //----------------------------------------------------------------------------------------------
    /// @brief      Gets the device kernel argument size.
    ///
    /// @param[in]  p_arg  The pointer to the Device op Argument.
    ///
    /// @return     The device kernel argument size.
    ///
    virtual size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const = 0;
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
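As a rough illustration of how a caller might pack per-group problem descriptions before handing them to SetDeviceKernelArgs, consider the sketch below. The helper name, the single-D configuration, and the row-major stride choices are assumptions for illustration; the concrete grouped-GEMM device op defines its own argument handling.

// Hypothetical host-side sketch: build one GroupedGemmTileLoopKernelArguments entry per group.
#include <array>
#include <cstddef>
#include <vector>
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp"

using KernelArg = ck::tensor_operation::device::GroupedGemmTileLoopKernelArguments<1>; // 1 D tensor

std::vector<KernelArg> make_group_args(const std::vector<const void*>& p_as,
                                       const std::vector<const void*>& p_bs,
                                       const std::vector<const void*>& p_ds,
                                       const std::vector<void*>& p_es,
                                       const std::vector<std::array<ck::index_t, 3>>& mnk)
{
    std::vector<KernelArg> args;
    for(std::size_t g = 0; g < mnk.size(); ++g)
    {
        const auto [M, N, K] = mnk[g];
        // Row-major strides assumed for illustration: A is MxK, B is KxN, D/E are MxN.
        args.emplace_back(p_as[g],
                          p_bs[g],
                          std::array<const void*, 1>{p_ds[g]},
                          p_es[g],
                          M, N, K,
                          /*StrideA=*/K,
                          /*StrideB=*/N,
                          std::array<ck::index_t, 1>{N},
                          /*StrideE=*/N);
    }
    return args;
}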
include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -627,7 +627,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
             arg.a_max_read_elems_ % ABlockTransferSrcScalarPerVector == 0;
         const bool valid_a_access_dim_m = ABlockTransferSrcVectorDim == 1 && arg.a_mz_consecutive_;
         const bool valid_a_access_dim_k = ABlockTransferSrcVectorDim == 2 && arg.a_kz_consecutive_;
-        const bool valid_a_access_dim = valid_a_access_dim_m || valid_a_access_dim_k;
+        const bool valid_a_access_dim =
+            valid_a_access_dim_m || valid_a_access_dim_k || ABlockTransferSrcScalarPerVector == 1;
         if(!(valid_a_vector_size && valid_a_access_dim))
         {
             return false;
@@ -637,7 +638,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
             arg.b_max_read_elems_ % BBlockTransferSrcScalarPerVector == 0;
         const bool valid_b_access_dim_n = BBlockTransferSrcVectorDim == 1 && arg.b_nz_consecutive_;
         const bool valid_b_access_dim_k = BBlockTransferSrcVectorDim == 2 && arg.b_kz_consecutive_;
-        const bool valid_b_access_dim = valid_b_access_dim_n || valid_b_access_dim_k;
+        const bool valid_b_access_dim =
+            valid_b_access_dim_n || valid_b_access_dim_k || BBlockTransferSrcScalarPerVector == 1;
         if(!(valid_b_vector_size && valid_b_access_dim))
         {
             return false;
@@ -648,7 +650,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
             const bool valid_d_vector_size =
                 arg.ds_max_read_elems_[i] % CDEBlockTransferScalarPerVector_NPerBlock == 0;
             // Vector read of Ds is always on N dimension.
-            const bool valid_d_access_dim = arg.ds_nz_consecutive_[i];
+            const bool valid_d_access_dim =
+                arg.ds_nz_consecutive_[i] || CDEBlockTransferScalarPerVector_NPerBlock == 1;
             if(!(valid_d_vector_size && valid_d_access_dim))
             {
                 valid_ds_access = false;
@@ -662,7 +665,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
         const bool valid_e_vector_size =
             arg.e_max_write_elems_ % CDEBlockTransferScalarPerVector_NPerBlock == 0;
         // Vector write of E is always on N dimension.
-        const bool valid_e_access_dim = arg.e_nz_consecutive_;
+        const bool valid_e_access_dim =
+            arg.e_nz_consecutive_ || CDEBlockTransferScalarPerVector_NPerBlock == 1;
         if(!(valid_e_vector_size && valid_e_access_dim))
         {
             return false;
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp
@@ -10,110 +10,30 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_
multiple_abd
.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_
v2
.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace
ck
{
template
<
typename
GridwiseGemm
,
typename
AsPointer
,
typename
BsPointer
,
typename
DsPointer
,
typename
EDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
,
typename
AsGridDesc_AK0_M_AK1
,
typename
BsGridDesc_BK0_N_BK1
,
typename
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
Block2ETileMap
,
bool
HasMainKBlockLoop
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
kernel_gemm_multiple_abd_xdl_cshuffle
(
AsPointer
p_as_grid
,
BsPointer
p_bs_grid
,
DsPointer
p_ds_grid
,
EDataType
*
__restrict__
p_e_grid
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
CDEElementwiseOperation
cde_element_op
,
const
AsGridDesc_AK0_M_AK1
as_grid_desc_ak0_m_ak1
,
const
BsGridDesc_BK0_N_BK1
bs_grid_desc_bk0_n_bk1
,
const
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
const
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock
,
const
Block2ETileMap
block_2_etile_map
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx94__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_as_grid
,
p_bs_grid
,
p_ds_grid
,
p_e_grid
,
p_shared
,
a_element_op
,
b_element_op
,
cde_element_op
,
as_grid_desc_ak0_m_ak1
,
bs_grid_desc_bk0_n_bk1
,
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
e_grid_desc_mblock_mperblock_nblock_nperblock
,
block_2_etile_map
);
#else
ignore
=
p_as_grid
;
ignore
=
p_bs_grid
;
ignore
=
p_ds_grid
;
ignore
=
p_e_grid
;
ignore
=
a_element_op
;
ignore
=
b_element_op
;
ignore
=
cde_element_op
;
ignore
=
as_grid_desc_ak0_m_ak1
;
ignore
=
bs_grid_desc_bk0_n_bk1
;
ignore
=
ds_grid_desc_mblock_mperblock_nblock_nperblock
;
ignore
=
e_grid_desc_mblock_mperblock_nblock_nperblock
;
ignore
=
block_2_etile_map
;
#endif
}
}
// namespace ck
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_abd.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
// GEMM:
// input : A[M, K]
// input : B[N, K]
// input : D0[M, N], D1[M, N], ...
// output : E[M, N]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Assume:
// D0, D1, ... and E have the same layout
template
<
typename
AsLayout
,
typename
BsLayout
,
typename
DsLayout
,
typename
E
Layout
,
typename
C
Layout
,
typename
AsDataType
,
typename
BsDataType
,
typename
AccDataType
,
typename
Gemm
AccDataType
,
typename
CShuffleDataType
,
typename
DsDataType
,
typename
E
DataType
,
typename
C
DataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
C
DE
ElementwiseOperation
,
typename
CElementwiseOperation
,
GemmSpecialization
GemmSpec
,
index_t
NumGemmKPrefetchStage
,
index_t
BlockSize
,
...
...
@@ -132,59 +52,56 @@ template <typename AsLayout,
index_t
ABlockTransferSrcVectorDim
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
ABlockTransferDstScalarPerVector_AK1
,
index_t
ABlockLdsExtraM
,
bool
ABlockLdsExtraM
,
typename
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
typename
BBlockTransferThreadClusterArrangeOrder
,
typename
BBlockTransferSrcAccessOrder
,
index_t
BBlockTransferSrcVectorDim
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
BBlockTransferDstScalarPerVector_BK1
,
index_t
BBlockLdsExtraN
,
bool
BBlockLdsExtraN
,
index_t
CShuffleMXdlPerWavePerShuffle
,
index_t
CShuffleNXdlPerWavePerShuffle
,
typename
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CDEBlockTransferScalarPerVector_NPerBlock
,
LoopScheduler
LoopSched
=
make_default_loop_scheduler
(),
PipelineVersion
PipelineVer
=
PipelineVersion
::
v1
>
typename
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
,
BlockGemmPipelineScheduler
BlkGemmPipeSched
=
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
BlkGemmPipelineVer
=
BlockGemmPipelineVersion
::
v1
,
typename
ComputeTypeA
=
CDataType
,
typename
ComputeTypeB
=
ComputeTypeA
>
struct
DeviceGemmMultipleABD_Xdl_CShuffle
:
public
DeviceGemmMultipleABD
<
AsLayout
,
BsLayout
,
DsLayout
,
E
Layout
,
C
Layout
,
AsDataType
,
BsDataType
,
DsDataType
,
E
DataType
,
C
DataType
,
AElementwiseOperation
,
BElementwiseOperation
,
C
DE
ElementwiseOperation
>
CElementwiseOperation
>
{
using
DeviceOp
=
DeviceGemmMultipleABD_Xdl_CShuffle
;
static
constexpr
index_t
NumATensor
=
AsDataType
::
Size
();
static
constexpr
index_t
NumBTensor
=
BsDataType
::
Size
();
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
using
ComputeDataType
=
EDataType
;
using
ALayout
=
remove_cvref_t
<
tuple_element_t
<
0
,
AsLayout
>>
;
using
BLayout
=
remove_cvref_t
<
tuple_element_t
<
0
,
BsLayout
>>
;
// GridwiseGemm
using
GridwiseGemm
=
GridwiseGemmMultipleABD_xdl_cshuffle
<
using
GridwiseGemm
=
GridwiseGemm_xdl_cshuffle_v3
<
ALayout
,
BLayout
,
CLayout
,
AsDataType
,
BsDataType
,
ComputeDataType
,
AccDataType
,
GemmAccDataType
,
CShuffleDataType
,
DsDataType
,
E
DataType
,
C
DataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
,
InMemoryDataOperationEnum
::
Set
,
NumGemmKPrefetchStage
,
CElementwiseOperation
,
GemmSpec
,
BlockSize
,
MPerBlock
,
NPerBlock
,
...
...
@@ -213,360 +130,476 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD<AsLayou
BBlockLdsExtraN
,
CShuffleMXdlPerWavePerShuffle
,
CShuffleNXdlPerWavePerShuffle
,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
CDEBlockTransferScalarPerVector_NPerBlock
,
LoopSched
,
PipelineVer
>
;
// desc for problem definition
using
AsGridDesc_M_K
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
template
MakeAsGridDescriptor_M_K
<
AsLayout
,
GemmSpec
>(
{},
{},
{}))
>
;
using
BsGridDesc_N_K
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
template
MakeBsGridDescriptor_N_K
<
BsLayout
,
GemmSpec
>(
{},
{},
{}))
>
;
using
DsGridDesc_M_N
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
template
MakeDsGridDescriptor_M_N
<
DsLayout
,
GemmSpec
>(
{},
{},
{}))
>
;
using
EGridDesc_M_N
=
decltype
(
GridwiseGemm
::
template
MakeEGridDescriptor_M_N
<
ELayout
,
GemmSpec
>(
1
,
1
,
1
));
// desc for blockwise copy
using
AsGridDesc_AK0_M_AK1
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDefaultAsGridDescriptor_AK0_M_AK1
(
AsGridDesc_M_K
{}))
>
;
using
BsGridDesc_BK0_N_BK1
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDefaultBsGridDescriptor_BK0_N_BK1
(
BsGridDesc_N_K
{}))
>
;
using
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
DsGridDesc_M_N
{}))
>
;
using
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
EGridDesc_M_N
{}))
>
;
// block-to-e-tile map
using
Block2ETileMap
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDefaultBlock2ETileMap
(
EGridDesc_M_N
{}))
>
;
// Argument
struct
Argument
:
public
BaseArgument
{
Argument
(
std
::
array
<
const
void
*
,
NumATensor
>
p_as_grid
,
std
::
array
<
const
void
*
,
NumBTensor
>
p_bs_grid
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds_grid
,
void
*
p_e_grid
,
index_t
MRaw
,
index_t
NRaw
,
index_t
KRaw
,
std
::
array
<
index_t
,
NumATensor
>
StrideAs
,
std
::
array
<
index_t
,
NumBTensor
>
StrideBs
,
std
::
array
<
index_t
,
NumDTensor
>
StrideDs
,
index_t
StrideE
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
)
:
p_as_grid_
{},
p_bs_grid_
{},
p_ds_grid_
{},
p_e_grid_
{
static_cast
<
EDataType
*>
(
p_e_grid
)},
as_grid_desc_m_k_
{},
bs_grid_desc_n_k_
{},
ds_grid_desc_m_n_
{},
e_grid_desc_m_n_
{
GridwiseGemm
::
template
MakeEGridDescriptor_M_N
<
ELayout
,
GemmSpec
>(
MRaw
,
NRaw
,
StrideE
)},
as_grid_desc_ak0_m_ak1_
{},
bs_grid_desc_bk0_n_bk1_
{},
ds_grid_desc_mblock_mperblock_nblock_nperblock_
{},
e_grid_desc_mblock_mperblock_nblock_nperblock_
{},
block_2_etile_map_
{
GridwiseGemm
::
MakeDefaultBlock2ETileMap
(
e_grid_desc_m_n_
)},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
cde_element_op_
{
cde_element_op
},
MRaw_
{
MRaw
},
NRaw_
{
NRaw
},
KRaw_
{
KRaw
}
{
            // populate pointer, desc for As
            static_for<0, NumATensor, 1>{}([&](auto i) {
                using ALayout   = remove_cvref_t<tuple_element_t<i.value, AsLayout>>;
                using ADataType = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;

                // A pointer
                p_as_grid_(i) = static_cast<const ADataType*>(p_as_grid[i]);

                // A desc
                as_grid_desc_m_k_(i) =
                    GridwiseGemm::template MakeAGridDescriptor_M_K<ALayout, GemmSpec>(
                        MRaw, KRaw, StrideAs[i]);
            });

            // populate pointer, desc for Bs
            static_for<0, NumBTensor, 1>{}([&](auto i) {
                using BLayout   = remove_cvref_t<tuple_element_t<i.value, BsLayout>>;
                using BDataType = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;

                // B pointer
                p_bs_grid_(i) = static_cast<const BDataType*>(p_bs_grid[i]);

                // B desc
                bs_grid_desc_n_k_(i) =
                    GridwiseGemm::template MakeBGridDescriptor_N_K<BLayout, GemmSpec>(
                        NRaw, KRaw, StrideBs[i]);
            });

            // populate pointer, desc for Ds
            static_for<0, NumDTensor, 1>{}([&](auto i) {
                using DLayout   = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
                using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;

                // D pointer
                p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);

                // D desc
                ds_grid_desc_m_n_(i) =
                    GridwiseGemm::template MakeEGridDescriptor_M_N<DLayout, GemmSpec>(
                        MRaw, NRaw, StrideDs[i]);
            });

            // populate desc for Ds/E
            if(GridwiseGemm::CheckValidity(as_grid_desc_m_k_,
                                           bs_grid_desc_n_k_,
                                           ds_grid_desc_m_n_,
                                           e_grid_desc_m_n_,
                                           block_2_etile_map_))
            {
                as_grid_desc_ak0_m_ak1_ =
                    GridwiseGemm::MakeDefaultAsGridDescriptor_AK0_M_AK1(as_grid_desc_m_k_);

                bs_grid_desc_bk0_n_bk1_ =
                    GridwiseGemm::MakeDefaultBsGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k_);

                ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
                    GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                        ds_grid_desc_m_n_);

                e_grid_desc_mblock_mperblock_nblock_nperblock_ =
                    GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                        e_grid_desc_m_n_);
            }
        }

        // private:
        // pointers
        typename GridwiseGemm::AsGridPointer p_as_grid_;
        typename GridwiseGemm::BsGridPointer p_bs_grid_;
        typename GridwiseGemm::DsGridPointer p_ds_grid_;
        EDataType* p_e_grid_;

        // tensor descriptors for problem definition
        AsGridDesc_M_K as_grid_desc_m_k_;
        BsGridDesc_N_K bs_grid_desc_n_k_;
        DsGridDesc_M_N ds_grid_desc_m_n_;
        EGridDesc_M_N e_grid_desc_m_n_;

        // tensor descriptors for block/thread-wise copy
        AsGridDesc_AK0_M_AK1 as_grid_desc_ak0_m_ak1_;
        BsGridDesc_BK0_N_BK1 bs_grid_desc_bk0_n_bk1_;
        DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock_;
        EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_;

        // block-to-e-tile map
        Block2ETileMap block_2_etile_map_;

        // element-wise op
        AElementwiseOperation a_element_op_;
        BElementwiseOperation b_element_op_;
        CDEElementwiseOperation cde_element_op_;

        // for checking vector load/store
        index_t MRaw_;
        index_t NRaw_;
        index_t KRaw_;
    };

    // Invoker
    struct Invoker : public BaseInvoker
    {
        using Argument = DeviceOp::Argument;

        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            if(!GridwiseGemm::CheckValidity(arg.as_grid_desc_m_k_,
                                            arg.bs_grid_desc_n_k_,
                                            arg.ds_grid_desc_m_n_,
                                            arg.e_grid_desc_m_n_,
                                            arg.block_2_etile_map_))
            {
                throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
            }

            const index_t grid_size =
                arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);

            auto launch_kernel = [&](auto has_main_k_block_loop) {
                constexpr bool has_main_loop = has_main_k_block_loop.value;

                const auto kernel = kernel_gemm_multiple_abd_xdl_cshuffle<
                    GridwiseGemm,
                    typename GridwiseGemm::AsGridPointer,
                    typename GridwiseGemm::BsGridPointer,
                    typename GridwiseGemm::DsGridPointer,
                    EDataType,
                    AElementwiseOperation,
                    BElementwiseOperation,
                    CDEElementwiseOperation,
                    DeviceOp::AsGridDesc_AK0_M_AK1,
                    DeviceOp::BsGridDesc_BK0_N_BK1,
                    DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
                    DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
                    DeviceOp::Block2ETileMap,
                    has_main_loop>;

                return launch_and_time_kernel(
                    stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0,
                    arg.p_as_grid_, arg.p_bs_grid_, arg.p_ds_grid_, arg.p_e_grid_,
                    arg.a_element_op_, arg.b_element_op_, arg.cde_element_op_,
                    arg.as_grid_desc_ak0_m_ak1_, arg.bs_grid_desc_bk0_n_bk1_,
                    arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
                    arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
                    arg.block_2_etile_map_);
            };

            const auto K = arg.as_grid_desc_m_k_[I0].GetLength(I1);

            if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
            {
                return launch_kernel(integral_constant<bool, true>{});
            }
            else
            {
                return launch_kernel(integral_constant<bool, false>{});
            }
        }

        // polymorphic
        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };

    static bool IsSupportedArgument(const Argument& arg)
    {
        if(!ck::is_xdl_supported())
        {
            return false;
        }

        // check vector load/store
        using Row = ck::tensor_layout::gemm::RowMajor;
        using Col = ck::tensor_layout::gemm::ColumnMajor;

        bool all_valid = true;

        static_for<0, NumATensor, 1>{}([&](auto i) {
            using ALayout = remove_cvref_t<tuple_element_t<i.value, AsLayout>>;

            // check vector load of A
            if constexpr(is_same_v<ALayout, Row> && ABlockTransferSrcVectorDim == 2)
            {
                if(arg.KRaw_ % ABlockTransferSrcScalarPerVector != 0)
                {
                    all_valid = false;
                }
            }
            else if constexpr(is_same_v<ALayout, Col> && ABlockTransferSrcVectorDim == 1)
            {
                // FIXME: not rigorous
                if(arg.MRaw_ % ABlockTransferSrcScalarPerVector != 0)
                {
                    all_valid = false;
                }
            }
            else
            {
                if(ABlockTransferSrcScalarPerVector != 1)
                {
                    all_valid = false;
                }
            }
        });

        static_for<0, NumBTensor, 1>{}([&](auto i) {
            using BLayout = remove_cvref_t<tuple_element_t<i.value, BsLayout>>;

            // check vector load of B
            if constexpr(is_same_v<BLayout, Col> && BBlockTransferSrcVectorDim == 2)
            {
                if(arg.KRaw_ % BBlockTransferSrcScalarPerVector != 0)
                {
                    all_valid = false;
                }
            }
            else if constexpr(is_same_v<BLayout, Row> && BBlockTransferSrcVectorDim == 1)
            {
                // FIXME: not rigorous
                if(arg.NRaw_ % BBlockTransferSrcScalarPerVector != 0)
                {
                    all_valid = false;
                }
            }
            else
            {
                if(BBlockTransferSrcScalarPerVector != 1)
                {
                    all_valid = false;
                }
            }
        });

        // check vector load of Ds
        // only support RowMajor for now
        static_for<0, NumDTensor, 1>{}([&](auto i) {
            using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;

            if constexpr(!is_same_v<DLayout, Row>)
            {
                all_valid = false;
            }
        });

        // check vector store of E
        // only support RowMajor for now
        if constexpr(is_same_v<ELayout, Row>)
        {
            if(arg.NRaw_ % CDEBlockTransferScalarPerVector_NPerBlock != 0)
            {
                all_valid = false;
            }
        }
        else
        {
            all_valid = false;
        }

        if(!all_valid)
        {
            return false;
        }

        return GridwiseGemm::CheckValidity(arg.as_grid_desc_m_k_,
                                           arg.bs_grid_desc_n_k_,
                                           arg.ds_grid_desc_m_n_,
                                           arg.e_grid_desc_m_n_,
                                           arg.block_2_etile_map_);
    }

    // The other side of this hunk rebuilds the device op around the universal blockwise-GEMM
    // pipeline (GridwiseGemm::Argument, KBatch splitting, TailNumber dispatch). Its recovered
    // pieces, gathered together:

    //     (tail of the GridwiseGemm alias)
    //     CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
    //     CShuffleBlockTransferScalarPerVector_NPerBlock,
    //     BlkGemmPipeSched,
    //     BlkGemmPipelineVer,
    //     ComputeTypeA,
    //     ComputeTypeB>;

    using Argument = typename GridwiseGemm::Argument;

    // Invoker
    struct Invoker : public BaseInvoker
    {
        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            if(stream_config.log_level_ > 0)
            {
                arg.Print();
            }

            if(!GridwiseGemm::CheckValidity(arg))
            {
                throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
            }

            index_t gdx, gdy, gdz;
            std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch);

            float ave_time = 0;

            index_t k_grain = arg.KBatch * KPerBlock;
            index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;

            const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);

            const auto Run = [&](const auto& kernel) {
                if(arg.KBatch > 1)
                    hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
                                                     0,
                                                     arg.M * arg.N * sizeof(CDataType),
                                                     stream_config.stream_id_));

                ave_time = launch_and_time_kernel(
                    stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
            };

            constexpr index_t minimum_occupancy =
                BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;

            if(has_main_k_block_loop)
            {
                // Tail number always full
                if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
                             BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
                {
#if 0
                    if(arg.KBatch > 1)
                    {
                        const auto kernel = kernel_gemm_xdl_cshuffle_v3<
                            GridwiseGemm, true, InMemoryDataOperationEnum::AtomicAdd,
                            minimum_occupancy>;
                        Run(kernel);
                    }
                    else
#endif
                    {
                        const auto kernel = kernel_gemm_xdl_cshuffle_v3<
                            GridwiseGemm, true, InMemoryDataOperationEnum::Set,
                            minimum_occupancy>;
                        Run(kernel);
                    }
                }
                // Tail number could be One to Seven
                else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
                {
#if 0
                    // (KBatch > 1 path: same TailNumber::One .. TailNumber::Seven ladder as
                    //  below, using InMemoryDataOperationEnum::AtomicAdd; currently disabled)
#endif
                    {
                        if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
                        {
                            const auto kernel = kernel_gemm_xdl_cshuffle_v3<
                                GridwiseGemm, true, InMemoryDataOperationEnum::Set,
                                minimum_occupancy, TailNumber::One>;
                            Run(kernel);
                        }
                        else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
                                TailNumber::Full)
                        {
                            const auto kernel = kernel_gemm_xdl_cshuffle_v3<
                                GridwiseGemm, true, InMemoryDataOperationEnum::Set,
                                minimum_occupancy, TailNumber::Full>;
                            Run(kernel);
                        }

                        if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
                        {
                            if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
                               TailNumber::Two)
                            {
                                const auto kernel = kernel_gemm_xdl_cshuffle_v3<
                                    GridwiseGemm, true, InMemoryDataOperationEnum::Set,
                                    minimum_occupancy, TailNumber::Two>;
                                Run(kernel);
                            }
                        }
                        // analogous "PrefetchStages > 3" .. "PrefetchStages > 7" blocks dispatch
                        // TailNumber::Three through TailNumber::Seven in the same way
                    }
                }
                // Tail number could be Odd or Even
                else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
                {
#if 0
                    // (KBatch > 1 path: Odd/Even dispatch of kernel_gemm_xdl_cshuffle_v3_2lds
                    //  with InMemoryDataOperationEnum::AtomicAdd; currently disabled)
#endif
                    {
                        if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
                        {
                            const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<
                                GridwiseGemm, true, InMemoryDataOperationEnum::Set,
                                minimum_occupancy, TailNumber::Odd>;
                            Run(kernel);
                        }
                        else
                        {
                            const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<
                                GridwiseGemm, true, InMemoryDataOperationEnum::Set,
                                minimum_occupancy, TailNumber::Even>;
                            Run(kernel);
                        }
                    }
                }
                else
                {
#if 0
                    // (KBatch > 1 path: Odd/Even dispatch with AtomicAdd; currently disabled)
#endif
                    {
                        if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
                        {
                            const auto kernel = kernel_gemm_xdl_cshuffle_v3<
                                GridwiseGemm, true, InMemoryDataOperationEnum::Set,
                                minimum_occupancy, TailNumber::Odd>;
                            Run(kernel);
                        }
                        else
                        {
                            const auto kernel = kernel_gemm_xdl_cshuffle_v3<
                                GridwiseGemm, true, InMemoryDataOperationEnum::Set,
                                minimum_occupancy, TailNumber::Even>;
                            Run(kernel);
                        }
                    }
                }
            }
            else
            {
                // Tail number always 1
                if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
                {
#if 0
                    // (KBatch > 1 path with AtomicAdd; currently disabled)
#endif
                    {
                        const auto kernel = kernel_gemm_xdl_cshuffle_v3<
                            GridwiseGemm, false, InMemoryDataOperationEnum::Set,
                            minimum_occupancy>;
                        Run(kernel);
                    }
                }
            }

            return ave_time;
        }

        // polymorphic
        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    static bool IsSupportedArgument(const Argument& arg)
    {
        if(!ck::is_xdl_supported())
        {
            return false;
        }

        if((arg.K % AK1 != 0 || arg.K % BK1 != 0) &&
           !(GemmSpec == GemmSpecialization::MKPadding ||
             GemmSpec == GemmSpecialization::NKPadding ||
             GemmSpec == GemmSpecialization::MNKPadding ||
             GemmSpec == GemmSpecialization::KPadding))
        {
            return false;
        }

        return GridwiseGemm::CheckValidity(arg);
    }
// polymorphic
...
...
@@ -588,8 +621,27 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD<AsLayou
                             index_t StrideE,
                             AElementwiseOperation a_element_op,
                             BElementwiseOperation b_element_op,
                             CDEElementwiseOperation cde_element_op)
                             CElementwiseOperation c_element_op)
    {
        static_for<0, NumATensor, 1>{}([&](auto i) {
            using ALayout_ = remove_cvref_t<tuple_element_t<i.value, AsLayout>>;

            static_assert(is_same<ALayout_, ALayout>::value, "");
        });

        static_for<0, NumBTensor, 1>{}([&](auto i) {
            using BLayout_ = remove_cvref_t<tuple_element_t<i.value, BsLayout>>;

            static_assert(is_same<BLayout_, BLayout>::value, "");
        });

        static_for<0, NumDTensor, 1>{}([&](auto i) {
            using DLayout_ = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;

            static_assert(is_same<DLayout_, CLayout>::value, "");
        });

        return Argument{p_as,
                        p_bs,
                        p_ds,
...
...
@@ -601,16 +653,16 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD<AsLayou
                        StrideBs,
                        StrideDs,
                        StrideE,
                        1,
                        a_element_op,
                        b_element_op,
                        cde_element_op};
                        c_element_op};
    }

    static auto MakeInvoker() { return Invoker{}; }

    // polymorphic
    std::unique_ptr<BaseArgument>
    MakeArgumentPointer(std::array<const void*, NumATensor> p_as,
                        std::array<const void*, NumBTensor> p_bs,
                        std::array<const void*, NumDTensor> p_ds,
                        void* p_e,
...
...
@@ -623,7 +675,7 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD<AsLayou
                        index_t StrideE,
                        AElementwiseOperation a_element_op,
                        BElementwiseOperation b_element_op,
                        CDEElementwiseOperation cde_element_op) override
                        CElementwiseOperation c_element_op) override
    {
        return std::make_unique<Argument>(p_as,
                                          p_bs,
...
...
@@ -636,9 +688,10 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD<AsLayou
                                          StrideBs,
                                          StrideDs,
                                          StrideE,
                                          1,
                                          a_element_op,
                                          b_element_op,
                                          cde_element_op);
                                          c_element_op);
    }

    // polymorphic
...
...
@@ -652,35 +705,41 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD<AsLayou
    {
        auto str = std::stringstream();

        std::map<LoopScheduler, std::string> LoopSchedToString{
            {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};
        std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
            {BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
            {BlockGemmPipelineScheduler::Interwave, "Interwave"}};

        std::map<PipelineVersion, std::string> PipelineVersionToString{
            {PipelineVersion::v1, "v1"}, {PipelineVersion::v2, "v2"}};
        std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
            {BlockGemmPipelineVersion::v1, "v1"},
            {BlockGemmPipelineVersion::v2, "v2"},
            {BlockGemmPipelineVersion::v3, "v3"},
            {BlockGemmPipelineVersion::v4, "v4"},
            {BlockGemmPipelineVersion::v5, "v5"}};

        // clang-format off
        str << "DeviceGemmMultipleABD_Xdl_CShuffle"
        str << "DeviceGemmXdlUniversal"
            << "<"
            << BlockSize << ", "
            << MPerBlock << ", " << NPerBlock << ", " << KPerBlock << ", "
            << AK1 << ", " << BK1 << ", "
            << MPerXDL << ", " << NPerXDL << ", "
            << MXdlPerWave << ", " << NXdlPerWave << ", "
            << ABlockTransferSrcScalarPerVector << ", " << BBlockTransferSrcScalarPerVector << ", "
            << CShuffleMXdlPerWavePerShuffle << ", " << CShuffleNXdlPerWavePerShuffle << ", "
            << getGemmSpecializationString(GemmSpec) << ", "
            << getGemmSpecializationString(GemmSpec)
            << std::string(ALayout::name)[0] << std::string(BLayout::name)[0] << std::string(CLayout::name)[0]
            << ">"
            << " LoopScheduler: " << LoopSchedToString[LoopSched] << ", "
            << "PipelineVersion: " << PipelineVersionToString[PipelineVer];
            << " BlkSize: " << BlockSize << ", "
            << "BlkTile: " << MPerBlock << "x" << NPerBlock << "x" << KPerBlock << ", "
            << "WaveTile: " << MPerXDL << "x" << NPerXDL << ", "
            << "WaveMap: " << MXdlPerWave << "x" << NXdlPerWave << ", "
            << "VmemReadVec: " << ABlockTransferSrcScalarPerVector << "x" << BBlockTransferSrcScalarPerVector << ", "
            << "BlkGemmPipelineScheduler: " << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
            << "BlkGemmPipelineVersion: " << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
            << "BlkGemmPipelinePrefetchStages: " << GridwiseGemm::BlockwiseGemmPipe::PrefetchStages;
        // clang-format on

        return str.str();
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp
...
...
@@ -15,6 +15,7 @@
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/flush_cache.hpp"

namespace ck {
namespace tensor_operation {
...
...
@@ -151,14 +152,56 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
        const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);

        const auto Run = [&](const auto& kernel) {
            if(stream_config.flush_cache)
            {
                Argument arg_ = arg;
                ck::utility::RotatingMemWrapper<Argument> rotating_mem(
                    arg_,
                    stream_config.rotating_count,
                    arg_.M * arg_.K * sizeof(ADataType),
                    arg_.K * arg_.N * sizeof(BDataType));
                rotating_mem.Print();

                auto run_flush_cache = [&]() {
                    // flush icache
                    ck::utility::flush_icache();
                    // rotating mem
                    rotating_mem.Next();
                    // clear c mem
                    if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
                    {
                        if(arg_.KBatch > 1)
                            hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
                                                             0,
                                                             arg_.M * arg_.N * sizeof(CDataType),
                                                             stream_config.stream_id_));
                    }
                };

                ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
                    stream_config, run_flush_cache, kernel,
                    dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg_);
            }
            else
            {
                if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
                {
                    if(arg.KBatch > 1)
                        hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
                                                         0,
                                                         arg.M * arg.N * sizeof(CDataType),
                                                         stream_config.stream_id_));
                }

                ave_time = launch_and_time_kernel(
                    stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
            }
        };

        constexpr index_t minimum_occupancy =
...
...
@@ -171,6 +214,8 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
                BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
            {
                if(arg.KBatch > 1)
                {
                    if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
                    {
                        const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
...
...
@@ -179,6 +224,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
                                                                        minimum_occupancy>;
                        Run(kernel);
                    }
                }
                else
                {
                    const auto kernel =
...
...
@@ -193,11 +239,13 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
            else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
            {
                if(arg.KBatch > 1)
                {
                    if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
                    {
                        if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
                        {
                            const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
                                                                            true,
                                                                            InMemoryDataOperationEnum::AtomicAdd,
                                                                            minimum_occupancy,
...
...
@@ -207,8 +255,8 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
                        else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
                                TailNumber::Full)
                        {
                            const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
                                                                            true,
                                                                            InMemoryDataOperationEnum::AtomicAdd,
                                                                            minimum_occupancy,
...
...
@@ -218,7 +266,8 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
                        if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
                        {
                            if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
                            {
                                const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
...
...
@@ -277,7 +326,8 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
                        if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
                        {
                            if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
                            {
                                const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
...
...
@@ -304,6 +354,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
                        }
                    }
                }
                }
                else
                {
                    if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
...
...
@@ -421,6 +472,8 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
            else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
            {
                if(arg.KBatch > 1)
                {
                    if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
                    {
                        if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
                        {
...
...
@@ -443,6 +496,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
                            Run(kernel);
                        }
                    }
                }
                else
                {
                    if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
...
...
@@ -470,11 +524,13 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
            else
            {
                if(arg.KBatch > 1)
                {
                    if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
                    {
                        if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
                        {
                            const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
                                                                            true,
                                                                            InMemoryDataOperationEnum::AtomicAdd,
                                                                            minimum_occupancy,
...
...
@@ -483,8 +539,8 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
                        }
                        else
                        {
                            const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
                                                                            true,
                                                                            InMemoryDataOperationEnum::AtomicAdd,
                                                                            minimum_occupancy,
...
...
@@ -492,6 +548,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
                            Run(kernel);
                        }
                    }
                }
                else
                {
                    if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
...
...
@@ -522,7 +579,10 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
            // Tail number always 1
            if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
            {
                if(arg.KBatch > 1)
                {
                    if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
                    {
                        const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
...
...
@@ -531,6 +591,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
                                                                        minimum_occupancy>;
                        Run(kernel);
                    }
                }
                else
                {
                    const auto kernel =
...
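The flush_cache path above exists so that timed iterations do not re-hit warm caches: it rotates the A/B input buffers, flushes the instruction cache, and clears C before each launch. A hedged host-side sketch of that benchmarking loop shape follows; RotatingBuffer and fake_run are stand-ins for illustration, not the library's RotatingMemWrapper or kernel launcher.

    #include <cstddef>
    #include <vector>

    // Minimal stand-in: cycle through several copies of an input buffer so each timed
    // iteration reads cold memory, mirroring the rotating-memory idea above.
    struct RotatingBuffer
    {
        std::vector<std::vector<char>> copies;
        std::size_t next = 0;
        const char* Next() { const char* p = copies[next].data(); next = (next + 1) % copies.size(); return p; }
    };

    template <typename Kernel>
    float time_with_cold_caches(Kernel&& run_once, RotatingBuffer& a, RotatingBuffer& b, int iters)
    {
        float total_ms = 0.f;
        for(int i = 0; i < iters; ++i)
        {
            const char* pa = a.Next(); // rotate inputs (assumes enough copies to exceed the cache)
            const char* pb = b.Next();
            total_ms += run_once(pa, pb); // caller flushes icache / clears C inside run_once
        }
        return total_ms / iters;
    }

    int main()
    {
        RotatingBuffer a{{std::vector<char>(1 << 20), std::vector<char>(1 << 20)}, 0};
        RotatingBuffer b{{std::vector<char>(1 << 20), std::vector<char>(1 << 20)}, 0};
        auto fake_run = [](const char*, const char*) { return 1.0f; }; // stands in for a kernel launch
        return static_cast<int>(time_with_cold_caches(fake_run, a, b, 4));
    }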
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <sstream>
#include <tuple>

#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/hip_check_error.hpp"
#include "ck/host_utility/stream_utility.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/utility/loop_scheduler.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

///
/// @brief      Entry point kernel for device-wide Grouped GEMM operation.
///
/// @param[in]  gemm_descs_const  The pointer to the array of GEMM descriptor structures.
/// @param[in]  group_count       The number of GEMMs processed together.
///
/// @tparam     GridwiseGemm          The specific GridwiseGEMM algorithm implementation.
/// @tparam     GemmDesc              The structure holding all necessary descriptors and
///                                   other data needed for grouped gemm calculation and work
///                                   distribution.
/// @tparam     LocalBlock2ETileMap   The structure providing mapping between workgroup ids,
///                                   the data tiles to process and the output tiles.
///
template <typename GridwiseGemm,
          typename GemmDesc,
          GemmSpecialization GemmSpec,
          typename DsDataType,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          typename OffsettedBlockToCTileMap,
          typename LocalBlock2ETileMap,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CDEElementwiseOperation>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_grouped_gemm_multiple_d_xdl(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
                                           const index_t group_count,
                                           const AElementwiseOperation a_element_op,
                                           const BElementwiseOperation b_element_op,
                                           const CDEElementwiseOperation cde_element_op)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
    defined(__gfx94__))
    constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
    __shared__ uint8_t p_shared[shared_size];

    const auto gemm_desc_ptr =
        reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));

    constexpr auto NumDTensor = DsDataType::Size();

    index_t tile_id       = get_block_1d_id();
    index_t tile_offset   = 0;
    index_t group_id      = -1;
    index_t group_offset  = 0;
    index_t grid_size_grp = 0;

    index_t gemm_tile_id_start = 0;
    index_t gemm_tile_id_end   = 0;

    using AGridDescMK = remove_cvref_t<decltype(
        GridwiseGemm::template MakeAGridDescriptor_M_K<ALayout, GemmSpec>(1, 1, 1))>;
    using BGridDescNK = remove_cvref_t<decltype(
        GridwiseGemm::template MakeBGridDescriptor_N_K<BLayout, GemmSpec>(1, 1, 1))>;
    using EGridDescMN = remove_cvref_t<decltype(
        GridwiseGemm::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(1, 1, 1))>;
    using DsGridDescMN = remove_cvref_t<decltype(
        GridwiseGemm::template MakeDsGridDescriptor_M_N<DsLayout, GemmSpec>({}, {}, {}))>;

    index_t M = 0, N = 0, K = 0;
    index_t StrideA, StrideB, StrideE;
    std::array<index_t, NumDTensor> StrideDs;

    AGridDescMK a_grid_desc_mk;
    BGridDescNK b_grid_desc_nk;
    EGridDescMN e_grid_desc_mn;
    DsGridDescMN ds_grid_desc_mn;

    auto b2c_tile_map = OffsettedBlockToCTileMap(LocalBlock2ETileMap(1, 1), 1, 1);

    do
    {
        // Find corresponding GEMM group for our tile
        while(!(tile_id >= gemm_tile_id_start && tile_id < gemm_tile_id_end) &&
              group_id < group_count)
        {
            group_offset += grid_size_grp;
            group_id++;

            if(group_id >= group_count)
                return;

            M = gemm_desc_ptr[group_id].M;
            N = gemm_desc_ptr[group_id].N;
            K = gemm_desc_ptr[group_id].K;

            if(M * N * K == 0)
            {
                grid_size_grp = 0;
                continue;
            }

            b2c_tile_map =
                OffsettedBlockToCTileMap(LocalBlock2ETileMap(M, N), group_offset, tile_offset);
            grid_size_grp = b2c_tile_map.CalculateGridSize(M, N);

            gemm_tile_id_start = group_offset;
            gemm_tile_id_end   = group_offset + grid_size_grp;
        }

        StrideA  = gemm_desc_ptr[group_id].StrideA;
        StrideB  = gemm_desc_ptr[group_id].StrideB;
        StrideDs = gemm_desc_ptr[group_id].StrideDs;
        StrideE  = gemm_desc_ptr[group_id].StrideE;

        a_grid_desc_mk =
            GridwiseGemm::template MakeAGridDescriptor_M_K<ALayout, GemmSpec>(M, K, StrideA);
        b_grid_desc_nk =
            GridwiseGemm::template MakeBGridDescriptor_N_K<BLayout, GemmSpec>(K, N, StrideB);
        e_grid_desc_mn =
            GridwiseGemm::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(M, N, StrideE);

        static_for<0, NumDTensor, 1>{}([&](auto j) {
            using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
            ds_grid_desc_mn(j) =
                GridwiseGemm::template MakeEGridDescriptor_M_N<DLayout, GemmSpec>(M, N, StrideDs[j]);
        });

        using DsGridPointer = decltype(GridwiseGemm::MakeDsGridPointer());
        DsGridPointer p_ds_grid;

        static_for<0, NumDTensor, 1>{}([&](auto i) {
            using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
            p_ds_grid(i)    = static_cast<const DDataType*>(gemm_desc_ptr[group_id].p_ds_grid[i]);
        });

        bool has_main_kblock_loop =
            GridwiseGemm::CalculateHasMainKBlockLoop(a_grid_desc_mk.GetLength(Number<1>{}));
        // Update tile offset if we have moved within group
        b2c_tile_map.UpdateTileOffset(tile_offset);

        if(has_main_kblock_loop)
        {
            GridwiseGemm::template Run<true>(gemm_desc_ptr[group_id].p_a_grid,
                                             gemm_desc_ptr[group_id].p_b_grid,
                                             p_ds_grid,
                                             gemm_desc_ptr[group_id].p_e_grid,
                                             static_cast<void*>(p_shared),
                                             a_element_op,
                                             b_element_op,
                                             cde_element_op,
                                             a_grid_desc_mk,
                                             b_grid_desc_nk,
                                             ds_grid_desc_mn,
                                             e_grid_desc_mn,
                                             b2c_tile_map);
        }
        else
        {
            GridwiseGemm::template Run<false>(gemm_desc_ptr[group_id].p_a_grid,
                                              gemm_desc_ptr[group_id].p_b_grid,
                                              p_ds_grid,
                                              gemm_desc_ptr[group_id].p_e_grid,
                                              static_cast<void*>(p_shared),
                                              a_element_op,
                                              b_element_op,
                                              cde_element_op,
                                              a_grid_desc_mk,
                                              b_grid_desc_nk,
                                              ds_grid_desc_mn,
                                              e_grid_desc_mn,
                                              b2c_tile_map);
        }

        tile_id += get_grid_size();
        tile_offset += get_grid_size();

    } while(group_id < group_count);
#else
    ignore = gemm_descs_const;
    ignore = group_count;
    ignore = a_element_op;
    ignore = b_element_op;
    ignore = cde_element_op;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
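The doc comment and loop above describe a persistent-kernel work distribution: tiles of all groups are laid out back to back, each workgroup owns a global tile id, maps it to a (group, local tile) pair, runs one output tile, and then strides forward by the grid size. A hedged CPU model of that search follows; group sizes and worker count are illustrative, not taken from the kernel.

    #include <cstdio>
    #include <vector>

    int main()
    {
        const std::vector<int> tiles_per_group = {4, 1, 6}; // illustrative per-GEMM grid sizes
        const int grid_size = 3;                            // stands in for the persistent grid
        for(int worker = 0; worker < grid_size; ++worker)
        {
            for(int tile_id = worker;; tile_id += grid_size)
            {
                // find the group owning tile_id (same search as the device loop above)
                int group = -1, group_start = 0;
                for(int g = 0; g < static_cast<int>(tiles_per_group.size()); ++g)
                {
                    if(tile_id < group_start + tiles_per_group[g]) { group = g; break; }
                    group_start += tiles_per_group[g];
                }
                if(group < 0) break; // past the last group's tiles, this worker is done
                std::printf("worker %d -> group %d, local tile %d\n",
                            worker, group, tile_id - group_start);
            }
        }
    }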
template <typename ALayout, typename BLayout, typename DsLayout, typename ELayout,
          typename ADataType, typename BDataType, typename AccDataType, typename CShuffleDataType,
          typename DsDataType, typename EDataType,
          typename AElementwiseOperation, typename BElementwiseOperation,
          typename CDEElementwiseOperation,
          GemmSpecialization GemmSpec,
          ck::index_t NumGemmKPrefetchStage, ck::index_t BlockSize,
          ck::index_t MPerBlock, ck::index_t NPerBlock, ck::index_t KPerBlock,
          ck::index_t AK1, ck::index_t BK1,
          ck::index_t MPerXDL, ck::index_t NPerXDL,
          ck::index_t MXdlPerWave, ck::index_t NXdlPerWave,
          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_AK1,
          index_t ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_BK1,
          index_t BBlockLdsExtraN,
          index_t CShuffleMXdlPerWavePerShuffle,
          index_t CShuffleNXdlPerWavePerShuffle,
          typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
          PipelineVersion PipelineVer = PipelineVersion::v1,
          LoopScheduler LoopSched     = make_default_loop_scheduler(),
          typename ComputeDataType    = EDataType>
struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
    : public DeviceGroupedGemmTileLoop<ALayout, BLayout, DsLayout, ELayout,
                                       ADataType, BDataType, DsDataType, EDataType,
                                       AElementwiseOperation, BElementwiseOperation,
                                       CDEElementwiseOperation>
{
    using DeviceOp = DeviceGroupedGemmMultipleDXdlCShuffleTileLoop;

    static constexpr index_t NumDTensor = DsDataType::Size();

    using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
        ADataType, BDataType, ComputeDataType, AccDataType, CShuffleDataType,
        DsDataType, EDataType,
        AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation,
        InMemoryDataOperationEnum::Set,
        NumGemmKPrefetchStage, BlockSize,
        MPerBlock, NPerBlock, KPerBlock,
        AK1, BK1,
        MPerXDL, NPerXDL,
        MXdlPerWave, NXdlPerWave,
        ABlockTransferThreadClusterLengths_AK0_M_AK1,
        ABlockTransferThreadClusterArrangeOrder,
        ABlockTransferSrcAccessOrder,
        ABlockTransferSrcVectorDim,
        ABlockTransferSrcScalarPerVector,
        ABlockTransferDstScalarPerVector_AK1,
        false, // AThreadTransferSrcResetCoordinateAfterRun,
        ABlockLdsExtraM,
        BBlockTransferThreadClusterLengths_BK0_N_BK1,
        BBlockTransferThreadClusterArrangeOrder,
        BBlockTransferSrcAccessOrder,
        BBlockTransferSrcVectorDim,
        BBlockTransferSrcScalarPerVector,
        BBlockTransferDstScalarPerVector_BK1,
        false, // BThreadTransferSrcResetCoordinateAfterRun,
        BBlockLdsExtraN,
        CShuffleMXdlPerWavePerShuffle,
        CShuffleNXdlPerWavePerShuffle,
        CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
        CDEShuffleBlockTransferScalarPerVector_NPerBlock,
        LoopSched,
        PipelineVer>;

    template <typename UnderlyingBlockToCTileMap>
    struct OffsettedBlockToCTileMap
    {
        using underlying_type = UnderlyingBlockToCTileMap;

        __host__ __device__ OffsettedBlockToCTileMap(UnderlyingBlockToCTileMap block_to_ctile_map,
                                                     index_t group_offset,
                                                     index_t tile_offset)
            : block_to_ctile_map_{block_to_ctile_map},
              group_offset_{group_offset},
              tile_offset_{tile_offset}
        {
        }

        template <typename TopIdx>
        __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
        {
            return block_to_ctile_map_.CalculateBottomIndex(
                make_multi_index(idx_top[Number<0>{}] + tile_offset_ - group_offset_));
        }

        template <typename CTileIdx, typename CTileDim>
        __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
                                                 const CTileDim& c_tile_dim) const
        {
            return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim);
        }

        template <typename CGridDesc_M_N>
        __host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
        {
            return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n);
        }

        __host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const
        {
            return block_to_ctile_map_.CalculateGridSize(M, N);
        }

        __device__ void UpdateTileOffset(index_t offset) { tile_offset_ = offset; }

        UnderlyingBlockToCTileMap block_to_ctile_map_;
        index_t group_offset_;
        index_t tile_offset_;
    };

    using KernelArguments = GroupedGemmTileLoopKernelArguments<NumDTensor>;
    using Block2ETileMap  = BlockToCTileMap_N00_M0_N01Adapt<MPerBlock, NPerBlock>;
    using OffsetedLocalBlock2ETileMap = OffsettedBlockToCTileMap<Block2ETileMap>;

    // Argument
    struct Argument : public BaseArgument
    {
        Argument(std::vector<const void*>& /* p_As */,
                 std::vector<const void*>& /* p_Bs */,
                 std::vector<std::array<const void*, NumDTensor>>& /* p_Ds */,
                 std::vector<void*>& /* p_Es */,
                 std::vector<GemmDesc>& gemm_descs,
                 AElementwiseOperation a_element_op,
                 BElementwiseOperation b_element_op,
                 CDEElementwiseOperation cde_element_op,
                 int occupancy_num_blocks,
                 int gpu_cu_count)
            : group_count_{static_cast<index_t>(gemm_descs.size())},
              occupancy_num_blocks_{occupancy_num_blocks},
              gpu_cu_count_{gpu_cu_count},
              gemm_descs_{gemm_descs},
              a_element_op_{a_element_op},
              b_element_op_{b_element_op},
              cde_element_op_{cde_element_op},
              tile_count_{0}
        {
            for(const auto& desc : gemm_descs)
            {
                const auto M            = desc.M_;
                const auto N            = desc.N_;
                const auto b2c_tile_map = Block2ETileMap(M, N);
                tile_count_ += b2c_tile_map.CalculateGridSize(M, N);
            }
        }

        index_t group_count_;
        const void* p_dev_gemm_args_;
        int occupancy_num_blocks_;
        int gpu_cu_count_;
        const std::vector<GemmDesc>& gemm_descs_;
        AElementwiseOperation a_element_op_;
        BElementwiseOperation b_element_op_;
        CDEElementwiseOperation cde_element_op_;
        index_t tile_count_;
    };

    struct KernelConfig
    {
        // The oversubscription factor for the number of blocks that can simultaneously reside on
        // GPU.
        static constexpr int BLOCK_SUBSCRIPTION_FACTOR = 1;
        static constexpr int BLOCK_WAVES               = BlockSize / get_warp_size();
        static constexpr int CU_SIMDS                  = 4;
        // Assume we want to have at most 2 waves per SIMD
        static constexpr int CU_BLOCKS = math::integer_divide_floor(2 * CU_SIMDS, BLOCK_WAVES);
    };

    // Invoker
    struct Invoker : public BaseInvoker
    {
        ///
        /// @brief      Launch Grouped Gemm kernel.
        ///
        /// @note       This function overload is using a user-provided device buffer for kernel
        ///             arguments.
        ///
        /// @param[in]  arg            The structure containing kernel arguments (in host
        ///                            memory).
        /// @param[in]  dev_gemm_args  The pointer to device memory with kernel arguments.
        /// @param[in]  stream_config  The device stream configuration.
        ///
        /// @return     The average kernel execution time (if time measurement is enabled).
        ///
        float Run(const Argument& arg,
                  const void* dev_gemm_args,
                  const StreamConfig& stream_config = StreamConfig{})
        {
            if(dev_gemm_args == nullptr)
            {
                std::ostringstream err;
                err << "The gemm arguments device buffer is not allocated!"
                    << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
                throw std::runtime_error(err.str());
            }

            float ave_time = 0;
            ave_time       = DispatchKernel(arg, dev_gemm_args, stream_config);

            return ave_time;
        }

        ///
        /// @brief      Launch Grouped Gemm kernel.
        ///
        /// @note       This function overload is using device buffers (for kernel arguments and
        ///             for kernel auxiliary workspace) provided with an argument. The user should
        ///             call @see GetDeviceKernelArgSize, and @see SetDeviceKernelArgs, on the arg
        ///             parameter to properly allocate those buffers.
        ///
        /// @param[in]  arg            The structure containing kernel arguments (in host memory).
        /// @param[in]  stream_config  The device stream configuration.
        ///
        /// @return     The average kernel execution time (if time measurement is enabled).
        ///
        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            if(arg.p_dev_gemm_args_ == nullptr)
            {
                std::ostringstream err;
                err << "The gemm arguments device buffer is not allocated!"
                    << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
                throw std::runtime_error(err.str());
            }

            return Run(arg, arg.p_dev_gemm_args_, stream_config);
        }

        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }

        private:
        float DispatchKernel(const Argument& arg,
                             const void* dev_gemm_args,
                             const StreamConfig& stream_config) const
        {
            const auto kernel = kernel_grouped_gemm_multiple_d_xdl<
                GridwiseGemm, KernelArguments, GemmSpec, DsDataType,
                ALayout, BLayout, DsLayout, ELayout,
                OffsetedLocalBlock2ETileMap, Block2ETileMap,
                AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation>;
            return LaunchKernel(kernel, arg, dev_gemm_args, stream_config);
        }

        template <typename KernelFunction>
        int CalculateMaxOccupancyGridSize(const KernelFunction& kernel,
                                          const StreamConfig& stream_config) const
        {
            // Calculate max number of workgroups that can simultaneously reside on the CU.
            int occ_num_blocks            = 0;
            size_t dyn_shared_mem_per_blk = 0;
            hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
                &occ_num_blocks, kernel, BlockSize, dyn_shared_mem_per_blk));

            int cu_count = getAvailableComputeUnitCount(stream_config);

            if(stream_config.log_level_ > 0)
            {
                std::cout << "MaxActiveBlocksPerCU: " << occ_num_blocks
                          << ", available CUs count: " << cu_count << ", occup. grid size: "
                          << ck::math::min(occ_num_blocks, KernelConfig::CU_BLOCKS) * cu_count
                          << std::endl;
            }

            return cu_count * ck::math::min(occ_num_blocks, KernelConfig::CU_BLOCKS);
        }

        template <typename KernelFunction>
        float LaunchKernel(const KernelFunction& kernel,
                           const Argument& arg,
                           const void* dev_gemm_args,
                           const StreamConfig& stream_config) const
        {
            int grid_size = CalculateMaxOccupancyGridSize(kernel, stream_config);

            if(stream_config.log_level_ > 0)
            {
                std::cout << "grid_size: " << grid_size << " tile_count: " << arg.tile_count_
                          << std::endl;
            }

            return launch_and_time_kernel(stream_config,
                                          kernel,
                                          dim3(grid_size),
                                          dim3(BlockSize),
                                          0,
                                          cast_pointer_to_constant_address_space(dev_gemm_args),
                                          arg.group_count_,
                                          arg.a_element_op_,
                                          arg.b_element_op_,
                                          arg.cde_element_op_);
        }
    };

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    static bool IsSupportedArgument(const Argument& arg)
    {
        if(!ck::is_xdl_supported())
        {
            return false;
        }

        using DsGridDescMN = remove_cvref_t<decltype(
            GridwiseGemm::template MakeDsGridDescriptor_M_N<DsLayout, GemmSpec>({}, {}, {}))>;

        bool supported = true;
        for(const auto& gdesc : arg.gemm_descs_)
        {
            const auto M         = gdesc.M_;
            const auto N         = gdesc.N_;
            const auto K         = gdesc.K_;
            const auto StrideA   = gdesc.stride_A_;
            const auto StrideB   = gdesc.stride_B_;
            const auto StrideE   = gdesc.stride_C_;
            const auto& StrideDs = gdesc.stride_Ds_;

            // If M dimension is unknown at launch time then validate just NK.
            // If N or K dim is zero (or unknown) then the vector loads responsibility lies on
            // the user.
            if(N * K == 0)
                continue;

            const auto a_grid_desc_mk =
                GridwiseGemm::template MakeAGridDescriptor_M_K<ALayout, GemmSpec>(M, K, StrideA);
            const auto b_grid_desc_nk =
                GridwiseGemm::template MakeBGridDescriptor_N_K<BLayout, GemmSpec>(K, N, StrideB);
            const auto e_grid_desc_mn =
                GridwiseGemm::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(M, N, StrideE);

            DsGridDescMN ds_grid_desc_mn;
            static_for<0, NumDTensor, 1>{}([&](auto j) {
                using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
                ds_grid_desc_mn(j) =
                    GridwiseGemm::template MakeEGridDescriptor_M_N<DLayout, GemmSpec>(
                        M, N, StrideDs[j]);
            });

            const auto b2c_tile_map = Block2ETileMap(M, N);

            if(!(GridwiseGemm::template CheckValidity(
                     a_grid_desc_mk, b_grid_desc_nk, ds_grid_desc_mn, e_grid_desc_mn, b2c_tile_map) &&
                 GridwiseGemm::template CheckTensorTransfersValidity<ALayout, BLayout, ELayout>(
                     M, N, K)))
            {
#if DEBUG_LOG
                std::cout << "The provided GEMM problem size (M,N,K) [" << M << "," << N << ","
                          << K << "] are not supported by current template parameters!"
                          << " In " << __FILE__ << ":" << __LINE__
                          << ", in function: " << __func__;
#endif
                supported = false;
            }
        }
        return supported;
    }

    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
    }

    static auto MakeArgument(std::vector<const void*>& p_As,
                             std::vector<const void*>& p_Bs,
                             std::vector<std::array<const void*, NumDTensor>>& p_Ds,
                             std::vector<void*>& p_Es,
                             std::vector<GemmDesc> gemm_descs,
                             AElementwiseOperation a_elementwise_op,
                             BElementwiseOperation b_elementwise_op,
                             CDEElementwiseOperation cde_elementwise_op)
    {
        const auto kernel = kernel_grouped_gemm_multiple_d_xdl<
            GridwiseGemm, KernelArguments, GemmSpec, DsDataType,
            ALayout, BLayout, DsLayout, ELayout,
            OffsetedLocalBlock2ETileMap, Block2ETileMap,
            AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation>;

        int occupancy, num_cu;
        hip_check_error(
            hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0));

        hipDeviceProp_t dev_prop;
        hipDevice_t dev;
        hip_check_error(hipGetDevice(&dev));
        hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
        num_cu = dev_prop.multiProcessorCount;

        return Argument{p_As, p_Bs, p_Ds, p_Es, gemm_descs,
                        a_elementwise_op, b_elementwise_op, cde_elementwise_op,
                        occupancy, num_cu};
    }

    std::unique_ptr<BaseArgument>
    MakeArgumentPointer(std::vector<const void*>& p_As,
                        std::vector<const void*>& p_Bs,
                        std::vector<std::array<const void*, NumDTensor>>& p_Ds,
                        std::vector<void*>& p_Es,
                        std::vector<GemmDesc>& gemm_descs,
                        AElementwiseOperation a_elementwise_op,
                        BElementwiseOperation b_elementwise_op,
                        CDEElementwiseOperation cde_elementwise_op) override
    {
        const auto kernel = kernel_grouped_gemm_multiple_d_xdl<
            GridwiseGemm, KernelArguments, GemmSpec, DsDataType,
            ALayout, BLayout, DsLayout, ELayout,
            OffsetedLocalBlock2ETileMap, Block2ETileMap,
            AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation>;

        int occupancy, num_cu;
        hip_check_error(
            hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0));

        hipDeviceProp_t dev_prop;
        hipDevice_t dev;
        hip_check_error(hipGetDevice(&dev));
        hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
        num_cu = dev_prop.multiProcessorCount;

        return std::make_unique<Argument>(p_As, p_Bs, p_Ds, p_Es, gemm_descs,
                                          a_elementwise_op, b_elementwise_op, cde_elementwise_op,
                                          occupancy, num_cu);
    }

    static auto MakeInvoker() { return Invoker{}; }

    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    std::string GetTypeString() const override
    {
        auto str = std::ostringstream();
        // clang-format off
        str << "DeviceGroupedGemmMultipleDXdlCShuffleTileLoop"
            << "<"
            << std::string(ALayout::name)[0] << ","
            << std::string(BLayout::name)[0] << ","
            << std::string(ELayout::name)[0] << ","
            << BlockSize << ", "
            << MPerBlock << ", " << NPerBlock << ", " << KPerBlock << ", "
            << AK1 << ", " << BK1 << ", "
            << MPerXDL << ", " << NPerXDL << ", "
            << MXdlPerWave << ", " << NXdlPerWave << ", "
            << ABlockTransferSrcScalarPerVector << ", "
            << BBlockTransferSrcScalarPerVector << ", "
            << CShuffleMXdlPerWavePerShuffle << ", "
            << CShuffleNXdlPerWavePerShuffle << ", "
            << getGemmSpecializationString(GemmSpec) << ", "
            << PipelineVer << ", "
            << LoopSched
            << ">";
        // clang-format on
        return str.str();
    }

    void SetDeviceKernelArgs(Argument& arg, void* p_dev_kernel_args) const
    {
        arg.p_dev_gemm_args_ = p_dev_kernel_args;
    }

    void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override
    {
        return SetDeviceKernelArgs(*dynamic_cast<Argument*>(p_arg), p_dev_kernel_args);
    }

    size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override
    {
        return dynamic_cast<const Argument*>(p_arg)->group_count_ * sizeof(KernelArguments);
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
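CalculateMaxOccupancyGridSize above sizes the persistent grid as min(blocks per CU reported by the occupancy API, the CU_BLOCKS cap derived from the waves-per-SIMD assumption in KernelConfig) multiplied by the CU count. A hedged arithmetic sketch of that sizing; all numbers are illustrative, nothing is queried from a real device.

    #include <algorithm>
    #include <cstdio>

    int main()
    {
        // Illustrative values: a 256-thread block is 4 waves of 64; capping at 2 waves per SIMD
        // on a 4-SIMD CU gives CU_BLOCKS = (2 * 4) / 4 = 2 resident blocks per CU.
        const int block_size = 256, warp_size = 64, cu_simds = 4;
        const int block_waves = block_size / warp_size;
        const int cu_blocks   = (2 * cu_simds) / block_waves;
        const int occ_blocks  = 3;   // stands in for hipOccupancyMaxActiveBlocksPerMultiprocessor
        const int cu_count    = 104; // stands in for multiProcessorCount
        const int grid_size   = std::min(occ_blocks, cu_blocks) * cu_count;
        std::printf("persistent grid size = %d\n", grid_size); // 2 * 104 = 208
    }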
include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp
...
...
@@ -92,15 +92,6 @@ struct Add
    }
};

struct Scales
{
    template <typename Y, typename X0, typename X1>
    __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const
    {
        y = ck::type_convert<Y>(ck::type_convert<float>(x0) * ck::type_convert<float>(x1));
    }
};

struct Max
{
    template <typename Y, typename X0, typename X1>
...
...
@@ -188,6 +179,16 @@ struct Multiply
        y = ck::type_convert<bhalf_t>(y_tmp);
    }

    template <>
    __host__ __device__ constexpr void
    operator()<bhalf_t>(bhalf_t& y, const int8_t& x0, const bhalf_t& x1) const
    {
        const float x1_tmp = ck::type_convert<float>(x0);
        const float x2_tmp = ck::type_convert<float>(x1);
        const float y_tmp  = x1_tmp * x2_tmp;
        y                  = ck::type_convert<bhalf_t>(y_tmp);
    }

    template <>
    __host__ __device__ constexpr void
    operator()<bhalf_t>(bhalf_t& y, const float& x0, const bhalf_t& x1) const
...
...
@@ -521,6 +522,71 @@ struct AddFastGelu
    }
};

// E = FastGelu(C * D)
struct MultiplyFastGelu
{
    template <typename E, typename C, typename D>
    __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const;

    template <>
    __host__ __device__ constexpr void
    operator()<float, float, float>(float& e, const float& c, const float& d) const
    {
        const float x = c * d;
        FastGelu{}.template operator()<float, float>(e, x);
    }

    template <>
    __host__ __device__ constexpr void
    operator()<half_t, half_t, half_t>(half_t& e, const half_t& c, const half_t& d) const
    {
        const half_t x = c * d;
        ck::tensor_operation::element_wise::FastGelu{}.template operator()<half_t, half_t>(e, x);
    }

    template <>
    __host__ __device__ constexpr void
    operator()<half_t, float, half_t>(half_t& e, const float& c, const half_t& d) const
    {
        const float x0_f = c * d;
        float x1_f       = 0;
        ck::tensor_operation::element_wise::FastGelu{}.template operator()<float, float>(x1_f, x0_f);
        e = type_convert<half_t>(x1_f);
    }

    template <>
    __host__ __device__ constexpr void
    operator()<bhalf_t, bhalf_t, bhalf_t>(bhalf_t& e, const bhalf_t& c, const bhalf_t& d) const
    {
        const float x0_f = type_convert<float>(c) * type_convert<float>(d);
        float x1_f       = 0;
        FastGelu{}.template operator()<float, float>(x1_f, x0_f);
        e = type_convert<bhalf_t>(x1_f);
    }

    template <>
    __host__ __device__ constexpr void
    operator()<bhalf_t, float, bhalf_t>(bhalf_t& e, const float& c, const bhalf_t& d) const
    {
        const float x0_f = c * type_convert<float>(d);
        float x1_f       = 0;
        FastGelu{}.template operator()<float, float>(x1_f, x0_f);
        e = type_convert<bhalf_t>(x1_f);
    }
};

// E = Silu(C + D)
struct AddSilu
{
...
...
include/ck/tensor_operation/gpu/element/element_wise_operation.hpp
...
...
@@ -221,6 +221,15 @@ struct MultiplyAdd
        e = y;
    }

    template <>
    __host__ __device__ void operator()<bhalf_t, float, bhalf_t, bhalf_t>(bhalf_t& e,
                                                                          const float& c,
                                                                          const bhalf_t& d0,
                                                                          const bhalf_t& d1) const
    {
        const bhalf_t y = type_convert<bhalf_t>(c) * d0 + d1;
        e               = y;
    }

    template <>
    __host__ __device__ void operator()<float, float, half_t, half_t>(float& e,
                                                                      const float& c,
                                                                      const half_t& d0,
...
...
@@ -240,6 +249,26 @@ struct MultiplyAdd
    }
};

struct MultiplyAddFastGelu
{
    template <typename E, typename C, typename D0, typename D1>
    __host__ __device__ constexpr void operator()(E& e, const C& c, const D0& d0, const D1& d1) const;

    template <>
    __host__ __device__ constexpr void
    operator()<ck::bhalf_t, float, ck::bhalf_t, ck::bhalf_t>(ck::bhalf_t& e,
                                                             const float& c,
                                                             const ck::bhalf_t& d0,
                                                             const ck::bhalf_t& d1) const
    {
        const float x0_f = c * ck::type_convert<float>(d0) + ck::type_convert<float>(d1);

        float x1_f = 0;
        FastGelu{}.template operator()<float, float>(x1_f, x0_f);

        e = ck::type_convert<ck::bhalf_t>(x1_f);
    }
};

// E = FastGelu(C + D0 + D1)
struct AddAddFastGelu
{
...
...
@@ -499,6 +528,26 @@ struct UnaryTypeConvert<ck::bhalf_t, float>
    }
};

struct ConvInvscale
{
    /// @brief Op to multiply convolution results by inverted scale factors
    /// @param e  Output after scaling
    /// @param c  Convolution result
    /// @param d0 Input scale factor
    /// @param d1 Weights scale factor
    /// @param d2 Output scale factor
    template <typename E, typename C, typename D0, typename D1, typename D2>
    __host__ __device__ void operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const;

    template <>
    __host__ __device__ void operator()<f8_t, float, float, float, float>(
        f8_t& e, const float& c, const float& d0, const float& d1, const float& d2) const
    {
        e = type_convert<f8_t>(c / d0 / d1 / d2);
    };
};

} // namespace element_wise
} // namespace tensor_operation
} // namespace ck
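MultiplyFastGelu and MultiplyAddFastGelu above follow one pattern: form the fused pre-activation in float, run FastGelu on it, then narrow to the destination type. A hedged numeric sketch of that composition; the tanh-based formula below is the commonly used fast-GELU approximation and is an assumption here, the library's exact polynomial may differ.

    #include <cmath>
    #include <cstdio>

    // Illustrative only: x = c * d0 + d1 in float, then a fast-GELU approximation, then narrow.
    static float fast_gelu_approx(float x)
    {
        const float k = 0.7978845608f; // sqrt(2/pi); assumption: tanh-based approximation
        return 0.5f * x * (1.0f + std::tanh(k * (x + 0.044715f * x * x * x)));
    }

    int main()
    {
        const float c = 1.25f, d0 = 0.5f, d1 = -0.1f;
        const float e = fast_gelu_approx(c * d0 + d1); // MultiplyAddFastGelu-style fusion
        std::printf("e = %f\n", e);
    }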
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
...
...
@@ -504,6 +504,16 @@ struct FastGelu
        y = type_convert<half_t>(y_f);
    }

    template <>
    __host__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
    {
        float y_f;

        this->operator()<float, float>(y_f, x);

        y = type_convert<bhalf_t>(y_f);
    }

    template <>
    __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
    {
...
...
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once
...
...
@@ -151,7 +151,7 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
    {
    }

    __host__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
    __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
    {
        const auto M0 = math::integer_divide_ceil(M, MPerBlock);
        const auto N0 = math::integer_divide_ceil(N, NPerBlock);
...
...
@@ -275,7 +275,7 @@ struct BlockToCTileMap_Grouped_M00_N0_M01Adapt
    {
    }

    __host__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
    __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
    {
        const auto M0 = math::integer_divide_ceil(M, MPerBlock);
        const auto N0 = math::integer_divide_ceil(N, NPerBlock);
...
...
@@ -428,7 +428,7 @@ struct BlockToCTileMap_N00_M0_N01Adapt<MPerBlock, NPerBlock, void>
    {
    }

    __host__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
    __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
    {
        const auto M0 = math::integer_divide_ceil(M, MPerBlock);
        const auto N0 = math::integer_divide_ceil(N, NPerBlock);
...
...
@@ -900,6 +900,11 @@ struct OffsettedBlockToCTileMap
        return block_to_ctile_map_.CalculateGridSize(c_grid_desc_m_n);
    }

    __host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const
    {
        return block_to_ctile_map_.CalculateGridSize(M, N);
    }

    UnderlyingBlockToCTileMap block_to_ctile_map_;
    index_t block_start_;
};
...
...
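The CalculateGridSize overloads made __host__ __device__ above all compute the same quantity: the number of MPerBlock x NPerBlock output tiles covering the C/E matrix. A hedged, self-contained illustration; the ceil-divide helper and tile sizes are local to the example, not the library's.

    #include <cstdio>

    static int ceil_div(int a, int b) { return (a + b - 1) / b; }

    int main()
    {
        // Assumed tile shape for illustration: 256x128 workgroup tiles over a 1000x2000 output.
        const int MPerBlock = 256, NPerBlock = 128;
        const int M = 1000, N = 2000;
        const int grid_size = ceil_div(M, MPerBlock) * ceil_div(N, NPerBlock); // 4 * 16 = 64
        std::printf("grid_size = %d\n", grid_size);
    }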
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp
View file @
f0759faf
...
...
@@ -594,11 +594,6 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
            generate_tuple([&](auto) { return make_multi_index(0, m_block_data_idx_on_grid, 0); },
                           Number<NumATensor>{});

#if 0
        static_assert(ABlockTransferSrcScalarPerVector == ABlockTransferDstScalarPerVector_AK1,
                      "Src and Dst ScalarPerVector must be the same");
#endif
        auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2<ThisThreadBlock,
                                                                    AsDataType,
...
...
@@ -616,7 +611,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
                2,
                ABlockTransferSrcScalarPerVector,
                ABlockTransferDstScalarPerVector_AK1,
                uniform_sequence_gen_t<NumATensor, false>,
                uniform_sequence_gen_t<NumATensor, AThreadTransferSrcResetCoordinateAfterRun>,
                Sequence<true>>{as_grid_desc_ak0_m_ak1,
                                idx_as_block_begin,
                                tie(a_block_desc_ak0_m_ak1),
...
...
@@ -627,11 +622,6 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
            generate_tuple([&](auto) { return make_multi_index(0, n_block_data_idx_on_grid, 0); },
                           Number<NumBTensor>{});

#if 0
        static_assert(BBlockTransferSrcScalarPerVector == BBlockTransferDstScalarPerVector_BK1,
                      "Src and Dst ScalarPerVector must be the same");
#endif
        auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2<ThisThreadBlock,
                                                                    BsDataType,
...
...
@@ -649,7 +639,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
                2,
                BBlockTransferSrcScalarPerVector,
                BBlockTransferDstScalarPerVector_BK1,
                uniform_sequence_gen_t<NumBTensor, false>,
                uniform_sequence_gen_t<NumBTensor, BThreadTransferSrcResetCoordinateAfterRun>,
                Sequence<true>>{bs_grid_desc_bk0_n_bk1,
                                idx_bs_block_begin,
                                tie(b_block_desc_bk0_n_bk1),
...
...
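The copy parameterizations above pass uniform_sequence_gen_t<N, Flag> to ThreadGroupTensorSliceTransfer_v7r2, which I read as "a compile-time sequence of N copies of Flag" (an assumption about this CK helper). A standalone sketch of such a generator built on std::integer_sequence:

#include <type_traits>
#include <utility>

// Assumed behavior of a uniform sequence generator: N compile-time copies of V.
template <bool V, std::size_t... Is>
constexpr auto make_uniform_impl(std::index_sequence<Is...>)
{
    return std::integer_sequence<bool, (static_cast<void>(Is), V)...>{};
}

template <std::size_t N, bool V>
using uniform_bool_sequence_t = decltype(make_uniform_impl<V>(std::make_index_sequence<N>{}));

// Compile-time check: three copies of `false`.
static_assert(std::is_same_v<uniform_bool_sequence_t<3, false>,
                             std::integer_sequence<bool, false, false, false>>,
              "three copies of false");

int main() { return 0; }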
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
View file @ f0759faf
...
...
@@ -257,7 +257,70 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
            e_grid_desc_m_n);
    }

    // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
    template <typename ALayout, typename BLayout, typename ELayout>
    __host__ __device__ static bool
    CheckTensorTransfersValidity(index_t MRaw, index_t NRaw, index_t KRaw)
    {
        // Check if the vector dim is K1 or M|N
        const auto A_vector_dim_size = ABlockTransferSrcVectorDim == 2 ? KRaw : MRaw;
        const auto B_vector_dim_size = BBlockTransferSrcVectorDim == 2 ? KRaw : NRaw;
        const auto E_vector_dim_size = NRaw;

        // check vector load for A tensor
        if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
        {
            if(!(A_vector_dim_size == KRaw &&
                 A_vector_dim_size % ABlockTransferSrcScalarPerVector == 0))
                return false;
        }
        else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
        {
            if(!(A_vector_dim_size == MRaw &&
                 A_vector_dim_size % ABlockTransferSrcScalarPerVector == 0))
                return false;
        }
        else
        {
            return false;
        }

        if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
        {
            if(!(B_vector_dim_size == NRaw &&
                 B_vector_dim_size % BBlockTransferSrcScalarPerVector == 0))
                return false;
        }
        else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
        {
            if(!(B_vector_dim_size == KRaw &&
                 B_vector_dim_size % BBlockTransferSrcScalarPerVector == 0))
                return false;
        }
        else
        {
            return false;
        }

        if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ELayout>)
        {
            if(!(E_vector_dim_size == NRaw &&
                 E_vector_dim_size % CDEShuffleBlockTransferScalarPerVector_NPerBlock == 0))
                return false;
        }
        else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ELayout>)
        {
            if(!(E_vector_dim_size == NRaw &&
                 CDEShuffleBlockTransferScalarPerVector_NPerBlock == 1))
                return false;
        }
        else
        {
            return false;
        }

        return true;
    }
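CheckTensorTransfersValidity reduces to one rule per tensor: the fast-moving dimension implied by the layout must be the configured vector dimension and must be divisible by the scalar-per-vector width. A host-only sketch of that rule for the A tensor alone (the Layout enum and function name here are illustrative, not CK's):

#include <cstdio>

enum class Layout { RowMajor, ColumnMajor }; // illustrative, not CK's tensor_layout types

// Returns true if vector loads of `scalar_per_vector` elements are legal for an
// M x K matrix A whose vectorized dimension is K (when vector_dim_is_k) or M.
static bool a_transfer_is_valid(Layout layout, int M, int K, bool vector_dim_is_k,
                                int scalar_per_vector)
{
    const int vector_dim_size = vector_dim_is_k ? K : M;
    if(layout == Layout::RowMajor)
        return vector_dim_size == K && vector_dim_size % scalar_per_vector == 0;
    if(layout == Layout::ColumnMajor)
        return vector_dim_size == M && vector_dim_size % scalar_per_vector == 0;
    return false;
}

int main()
{
    // K = 512 is divisible by 8, so row-major A with K-vectorized loads is accepted.
    std::printf("%d\n", a_transfer_is_valid(Layout::RowMajor, 1000, 512, true, 8)); // 1
    // K = 500 is not divisible by 8, so the same configuration is rejected.
    std::printf("%d\n", a_transfer_is_valid(Layout::RowMajor, 1000, 500, true, 8)); // 0
    return 0;
}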
    template <typename AGridDesc_M_K,
              typename BGridDesc_N_K,
              typename DsGridDesc_M_N,
...
...
@@ -267,7 +330,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
                                  const BGridDesc_N_K& b_grid_desc_n_k,
                                  const DsGridDesc_M_N& ds_grid_desc_m_n,
                                  const EGridDesc_M_N& e_grid_desc_m_n,
                                  const Block2ETileMap&)
                                  [[maybe_unused]] const Block2ETileMap&)
    {
        static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
                          (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
...
...
@@ -285,7 +348,6 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
        {
            return false;
        }

        bool valid = true;

        static_for<0, NumDTensor, 1>{}([&](auto i) {
...
...
@@ -306,7 +368,6 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
        // check gridwise gemm pipeline
        const auto num_k_loop = AK / KPerBlock;

        if(!GridwiseGemmPipe::IsSupported(num_k_loop))
        {
            return false;
...
...
@@ -938,6 +999,63 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
            e_grid_desc_mblock_mperblock_nblock_nperblock,
            block_2_etile_map);
    }
    template <bool HasMainKBlockLoop,
              typename AGridDesc_MK,
              typename BGridDesc_NK,
              typename DsGridDesc_MN,
              typename EGridDesc_MN,
              typename Block2ETileMap>
    __device__ static void Run(const void* __restrict__ p_a_grid_,
                               const void* __restrict__ p_b_grid_,
                               DsGridPointer p_ds_grid,
                               void* __restrict__ p_e_grid_,
                               void* __restrict__ p_shared,
                               const AElementwiseOperation& a_element_op,
                               const BElementwiseOperation& b_element_op,
                               const CDEElementwiseOperation& cde_element_op,
                               const AGridDesc_MK& a_grid_desc_m_k,
                               const BGridDesc_NK& b_grid_desc_n_k,
                               const DsGridDesc_MN& ds_grid_desc_m_n,
                               const EGridDesc_MN& e_grid_desc_m_n,
                               const Block2ETileMap& block_2_etile_map)
    {
        const auto p_a_grid = reinterpret_cast<const ADataType*>(p_a_grid_);
        const auto p_b_grid = reinterpret_cast<const BDataType*>(p_b_grid_);
        const auto p_e_grid = reinterpret_cast<EDataType*>(p_e_grid_);

        // tensor descriptors for block/thread-wise copy
        const auto a_grid_desc_ak0_m_ak1 = MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k);
        const auto b_grid_desc_bk0_n_bk1 = MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k);

        using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
            remove_cvref_t<decltype(MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                DsGridDesc_MN{}))>;

        DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock;

        static_for<0, NumDTensor, 1>{}([&](auto j) {
            ds_grid_desc_mblock_mperblock_nblock_nperblock(j) =
                MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[j]);
        });

        const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
            MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n);

        Run<HasMainKBlockLoop>(p_a_grid,
                               p_b_grid,
                               p_ds_grid,
                               p_e_grid,
                               p_shared,
                               a_element_op,
                               b_element_op,
                               cde_element_op,
                               a_grid_desc_ak0_m_ak1,
                               b_grid_desc_bk0_n_bk1,
                               ds_grid_desc_mblock_mperblock_nblock_nperblock,
                               e_grid_desc_mblock_mperblock_nblock_nperblock,
                               block_2_etile_map);
    }
};

} // namespace ck
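The added Run overload is a type-erased wrapper: it takes void* grid pointers, reinterpret_casts them back to ADataType/BDataType/EDataType, rebuilds the blocked descriptors, and forwards to the existing typed Run. A minimal sketch of that wrapper pattern outside CK (the function names are invented for illustration):

#include <cstdio>

// Typed worker: the "real" implementation.
template <typename T>
static void run_typed(const T* in, T* out, int n)
{
    for(int i = 0; i < n; ++i)
        out[i] = in[i] * T{2};
}

// Type-erased wrapper, mirroring the void*-taking Run overload: cast, then forward.
template <typename T>
static void run_erased(const void* in_, void* out_, int n)
{
    const auto* in  = reinterpret_cast<const T*>(in_);
    auto*       out = reinterpret_cast<T*>(out_);
    run_typed<T>(in, out, n);
}

int main()
{
    float in[4] = {1.f, 2.f, 3.f, 4.f};
    float out[4];
    run_erased<float>(in, out, 4);
    std::printf("%f %f %f %f\n", out[0], out[1], out[2], out[3]);
    return 0;
}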
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp
View file @ f0759faf
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <ostream>
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp"
...
...
@@ -57,3 +58,16 @@ constexpr auto GridwiseGemmPipeline_Selector()
}
} // namespace ck

inline std::ostream& operator<<(std::ostream& os, const ck::PipelineVersion& p)
{
    switch(p)
    {
    case ck::PipelineVersion::v1: os << "PipelineVersion::v1"; break;
    case ck::PipelineVersion::v2: os << "PipelineVersion::v2"; break;
    case ck::PipelineVersion::v4: os << "PipelineVersion::v4"; break;
    case ck::PipelineVersion::weight_only: os << "PipelineVersion::weight_only"; break;
    default: os << "";
    }

    return os;
}
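A short, self-contained illustration of the stream-operator pattern added above, using a local enum rather than ck::PipelineVersion so it compiles without the CK headers:

#include <iostream>
#include <ostream>

// Local mirror of the pattern above: a stream operator for a scoped enum.
enum class PipelineVersion { v1, v2, v4, weight_only }; // illustrative copy, not CK's enum

inline std::ostream& operator<<(std::ostream& os, const PipelineVersion& p)
{
    switch(p)
    {
    case PipelineVersion::v1: os << "PipelineVersion::v1"; break;
    case PipelineVersion::v2: os << "PipelineVersion::v2"; break;
    case PipelineVersion::v4: os << "PipelineVersion::v4"; break;
    case PipelineVersion::weight_only: os << "PipelineVersion::weight_only"; break;
    default: os << "";
    }
    return os;
}

int main()
{
    std::cout << PipelineVersion::v2 << std::endl; // prints "PipelineVersion::v2"
    return 0;
}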