Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
f0759faf
Commit
f0759faf
authored
Apr 26, 2024
by
Jun Liu
Browse files
Merge branch 'develop' into amd-develop
parents
20ddaeba
764164b4
Changes
103
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2765 additions
and
662 deletions
+2765
-662
example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fastgelu_bf16_i8.cpp
...60_gemm_multi_ABD/gemm_multi_ABD_xdl_fastgelu_bf16_i8.cpp
+273
-0
example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_multiply_bias_fastgelu_bf16_i8.cpp
...ABD/gemm_multi_ABD_xdl_multiply_bias_fastgelu_bf16_i8.cpp
+274
-0
include/ck/host_utility/flush_cache.hpp
include/ck/host_utility/flush_cache.hpp
+229
-0
include/ck/stream_config.hpp
include/ck/stream_config.hpp
+3
-0
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp
...operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp
+6
-2
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp
...operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp
+40
-17
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp
...operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp
+2
-2
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp
...ion/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp
+16
-10
include/ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp
...or_operation/gpu/device/device_grouped_gemm_tile_loop.hpp
+128
-0
include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp
...evice/impl/device_contraction_multiple_d_xdl_cshuffle.hpp
+9
-5
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp
...gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp
+520
-461
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp
...operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp
+196
-135
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp
...device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp
+787
-0
include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp
...r_operation/gpu/element/binary_element_wise_operation.hpp
+75
-9
include/ck/tensor_operation/gpu/element/element_wise_operation.hpp
...k/tensor_operation/gpu/element/element_wise_operation.hpp
+49
-0
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
...or_operation/gpu/element/unary_element_wise_operation.hpp
+10
-0
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
+9
-4
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp
...tion/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp
+2
-12
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
...ration/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
+122
-4
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp
...or_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp
+15
-1
No files found.
example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fastgelu_bf16_i8.cpp
0 → 100644
View file @
f0759faf
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/utility/blkgemmpipe_scheduler.hpp"
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
F16
=
ck
::
half_t
;
using
BF16
=
ck
::
bhalf_t
;
using
I8
=
int8_t
;
using
F32
=
float
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
A0DataType
=
BF16
;
using
AsDataType
=
ck
::
Tuple
<
A0DataType
>
;
using
B0DataType
=
I8
;
using
B1DataType
=
BF16
;
using
BsDataType
=
ck
::
Tuple
<
B0DataType
,
B1DataType
>
;
using
AccDataType
=
F32
;
using
CShuffleDataType
=
F32
;
using
D0DataType
=
BF16
;
using
DsDataType
=
ck
::
Tuple
<>
;
using
EDataType
=
BF16
;
using
A0Layout
=
Row
;
using
AsLayout
=
ck
::
Tuple
<
A0Layout
>
;
using
B0Layout
=
Row
;
using
B1Layout
=
B0Layout
;
using
BsLayout
=
ck
::
Tuple
<
B0Layout
,
B1Layout
>
;
using
D0Layout
=
Row
;
using
DsLayout
=
ck
::
Tuple
<>
;
using
ELayout
=
Row
;
using
Multiply
=
ck
::
tensor_operation
::
element_wise
::
Multiply
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
FastGelu
=
ck
::
tensor_operation
::
element_wise
::
FastGelu
;
using
AElementOp
=
PassThrough
;
using
BElementOp
=
Multiply
;
using
CDEElementOp
=
FastGelu
;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
using
DeviceOpInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemmMultipleABD_Xdl_CShuffle
// clang-format off
///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
///######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
///######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
<
AsLayout
,
BsLayout
,
DsLayout
,
ELayout
,
AsDataType
,
BsDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
,
GemmSpec
,
1
,
256
,
128
,
128
,
64
,
8
,
4
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
4
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
ck
::
BlockGemmPipelineScheduler
::
Intrawave
,
ck
::
BlockGemmPipelineVersion
::
v4
>
;
// clang-format on
int
main
(
int
argc
,
char
*
argv
[])
{
bool
do_verification
=
true
;
int
init_method
=
1
;
bool
time_kernel
=
false
;
// GEMM shape
ck
::
index_t
M
=
4096
;
ck
::
index_t
N
=
768
;
ck
::
index_t
K
=
6144
;
ck
::
index_t
StrideA
=
K
;
ck
::
index_t
StrideB
=
N
;
ck
::
index_t
StrideD
=
0
;
ck
::
index_t
StrideE
=
N
;
if
(
argc
==
1
)
{
// use default case
}
else
if
(
argc
==
4
)
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
}
else
if
(
argc
==
11
)
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
M
=
std
::
stoi
(
argv
[
4
]);
N
=
std
::
stoi
(
argv
[
5
]);
K
=
std
::
stoi
(
argv
[
6
]);
StrideA
=
std
::
stoi
(
argv
[
7
]);
StrideB
=
std
::
stoi
(
argv
[
8
]);
StrideD
=
std
::
stoi
(
argv
[
9
]);
StrideE
=
std
::
stoi
(
argv
[
10
]);
}
else
{
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg3: time kernel (0=no, 1=yes)
\n
"
);
printf
(
"arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE
\n
"
);
exit
(
0
);
}
auto
f_host_tensor_descriptor
=
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
auto
layout
)
{
using
namespace
ck
::
literals
;
if
(
std
::
is_same
<
decltype
(
layout
),
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
return
HostTensorDescriptor
({
row
,
col
},
{
stride
,
1
_uz
});
}
else
{
return
HostTensorDescriptor
({
row
,
col
},
{
1
_uz
,
stride
});
}
};
Tensor
<
A0DataType
>
a0_m_k
(
f_host_tensor_descriptor
(
M
,
K
,
StrideA
,
A0Layout
{}));
Tensor
<
B0DataType
>
b0_k_n
(
f_host_tensor_descriptor
(
K
,
N
,
StrideB
,
B0Layout
{}));
Tensor
<
B1DataType
>
b1_k_n
(
f_host_tensor_descriptor
(
K
,
N
,
0
,
B1Layout
{}));
Tensor
<
D0DataType
>
d_m_n
(
f_host_tensor_descriptor
(
M
,
N
,
StrideD
,
D0Layout
{}));
Tensor
<
EDataType
>
e_m_n_host_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideE
,
ELayout
{}));
Tensor
<
EDataType
>
e_m_n_device_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideE
,
ELayout
{}));
std
::
cout
<<
"a0_m_k: "
<<
a0_m_k
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b0_k_n: "
<<
b0_k_n
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b1_k_n: "
<<
b1_k_n
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"d_m_n: "
<<
d_m_n
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"e_m_n: "
<<
e_m_n_host_result
.
mDesc
<<
std
::
endl
;
switch
(
init_method
)
{
case
0
:
break
;
case
1
:
a0_m_k
.
GenerateTensorValue
(
GeneratorTensor_2
<
A0DataType
>
{
-
5
,
5
});
b0_k_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
B0DataType
>
{
-
5
,
5
});
b1_k_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
B1DataType
>
{
0
,
5
});
d_m_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
D0DataType
>
{
-
5
,
5
});
break
;
default:
a0_m_k
.
GenerateTensorValue
(
GeneratorTensor_3
<
A0DataType
>
{
0.0
,
1.0
});
b0_k_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
B0DataType
>
{
-
5
,
5
});
b1_k_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
B1DataType
>
{
0
,
5
});
d_m_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
D0DataType
>
{
-
0.5
,
0.5
});
}
DeviceMem
a0_device_buf
(
sizeof
(
A0DataType
)
*
a0_m_k
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
b0_device_buf
(
sizeof
(
B0DataType
)
*
b0_k_n
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
b1_device_buf
(
sizeof
(
B1DataType
)
*
b1_k_n
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
d_device_buf
(
sizeof
(
D0DataType
)
*
d_m_n
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
e_device_buf
(
sizeof
(
EDataType
)
*
e_m_n_device_result
.
mDesc
.
GetElementSpaceSize
());
a0_device_buf
.
ToDevice
(
a0_m_k
.
mData
.
data
());
b0_device_buf
.
ToDevice
(
b0_k_n
.
mData
.
data
());
b1_device_buf
.
ToDevice
(
b1_k_n
.
mData
.
data
());
d_device_buf
.
ToDevice
(
d_m_n
.
mData
.
data
());
e_device_buf
.
ToDevice
(
e_m_n_device_result
.
mData
.
data
());
auto
a_element_op
=
AElementOp
{};
auto
b_element_op
=
BElementOp
{};
auto
cde_element_op
=
CDEElementOp
{};
constexpr
ck
::
index_t
NumATensor
=
1
;
constexpr
ck
::
index_t
NumBTensor
=
2
;
constexpr
ck
::
index_t
NumDTensor
=
0
;
// do GEMM
auto
device_op
=
DeviceOpInstance
{};
auto
invoker
=
device_op
.
MakeInvoker
();
auto
argument
=
device_op
.
MakeArgument
(
std
::
array
<
const
void
*
,
NumATensor
>
{
a0_device_buf
.
GetDeviceBuffer
()},
std
::
array
<
const
void
*
,
NumBTensor
>
{
b0_device_buf
.
GetDeviceBuffer
(),
b1_device_buf
.
GetDeviceBuffer
()},
std
::
array
<
const
void
*
,
NumDTensor
>
{},
e_device_buf
.
GetDeviceBuffer
(),
M
,
N
,
K
,
std
::
array
<
ck
::
index_t
,
NumATensor
>
{
StrideA
},
std
::
array
<
ck
::
index_t
,
NumBTensor
>
{
StrideB
,
0
},
std
::
array
<
ck
::
index_t
,
NumDTensor
>
{},
StrideE
,
a_element_op
,
b_element_op
,
cde_element_op
);
if
(
!
device_op
.
IsSupportedArgument
(
argument
))
{
throw
std
::
runtime_error
(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem"
);
}
float
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
time_kernel
});
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
M
*
N
*
K
;
std
::
size_t
num_btype
=
sizeof
(
A0DataType
)
*
M
*
K
+
sizeof
(
B0DataType
)
*
K
*
N
+
sizeof
(
EDataType
)
*
M
*
N
;
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s"
<<
std
::
endl
;
e_device_buf
.
FromDevice
(
e_m_n_device_result
.
mData
.
data
());
if
(
do_verification
)
{
Tensor
<
CShuffleDataType
>
c_m_n
({
M
,
N
});
Tensor
<
A0DataType
>
a_m_k
({
M
,
K
});
Tensor
<
B1DataType
>
b_k_n
(
f_host_tensor_descriptor
(
K
,
N
,
StrideB
,
B0Layout
{}));
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
b_element_op
(
b_k_n
(
k
,
n
),
b0_k_n
(
k
,
n
),
b1_k_n
(
k
,
n
));
}
}
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
A0DataType
,
B1DataType
,
CShuffleDataType
,
AccDataType
,
PassThrough
,
PassThrough
,
PassThrough
>
;
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
auto
ref_argument
=
ref_gemm
.
MakeArgument
(
a0_m_k
,
b_k_n
,
c_m_n
,
PassThrough
{},
PassThrough
{},
PassThrough
{});
ref_invoker
.
Run
(
ref_argument
);
for
(
int
m
=
0
;
m
<
M
;
++
m
)
{
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
cde_element_op
(
e_m_n_host_result
(
m
,
n
),
c_m_n
(
m
,
n
));
}
}
e_device_buf
.
FromDevice
(
e_m_n_device_result
.
mData
.
data
());
return
ck
::
utils
::
check_err
(
e_m_n_device_result
,
e_m_n_host_result
)
?
0
:
1
;
}
return
0
;
}
example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_multiply_bias_fastgelu_bf16_i8.cpp
0 → 100644
View file @
f0759faf
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/utility/blkgemmpipe_scheduler.hpp"
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
F16
=
ck
::
half_t
;
using
BF16
=
ck
::
bhalf_t
;
using
I8
=
int8_t
;
using
F32
=
float
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
A0DataType
=
BF16
;
using
AsDataType
=
ck
::
Tuple
<
A0DataType
>
;
using
B0DataType
=
I8
;
using
B1DataType
=
BF16
;
using
BsDataType
=
ck
::
Tuple
<
B0DataType
>
;
using
AccDataType
=
F32
;
using
CShuffleDataType
=
F32
;
using
D0DataType
=
BF16
;
using
DsDataType
=
ck
::
Tuple
<
B1DataType
,
D0DataType
>
;
using
EDataType
=
BF16
;
using
A0Layout
=
Row
;
using
AsLayout
=
ck
::
Tuple
<
A0Layout
>
;
using
B0Layout
=
Row
;
using
B1Layout
=
B0Layout
;
using
BsLayout
=
ck
::
Tuple
<
B0Layout
>
;
using
D0Layout
=
Row
;
using
DsLayout
=
ck
::
Tuple
<
B1Layout
,
D0Layout
>
;
using
ELayout
=
Row
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
MultiplyAddFastGelu
=
ck
::
tensor_operation
::
element_wise
::
MultiplyAddFastGelu
;
using
AElementOp
=
PassThrough
;
using
BElementOp
=
PassThrough
;
using
CDEElementOp
=
MultiplyAddFastGelu
;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
using
DeviceOpInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemmMultipleABD_Xdl_CShuffle
// clang-format off
///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
///######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
///######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
<
AsLayout
,
BsLayout
,
DsLayout
,
ELayout
,
AsDataType
,
BsDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
,
GemmSpec
,
1
,
256
,
128
,
128
,
64
,
8
,
4
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
4
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
ck
::
BlockGemmPipelineScheduler
::
Intrawave
,
ck
::
BlockGemmPipelineVersion
::
v4
>
;
// clang-format on
int
main
(
int
argc
,
char
*
argv
[])
{
bool
do_verification
=
true
;
int
init_method
=
1
;
bool
time_kernel
=
false
;
// GEMM shape
ck
::
index_t
M
=
4096
;
ck
::
index_t
N
=
768
;
ck
::
index_t
K
=
6144
;
ck
::
index_t
StrideA
=
K
;
ck
::
index_t
StrideB
=
N
;
ck
::
index_t
StrideD
=
0
;
ck
::
index_t
StrideE
=
N
;
if
(
argc
==
1
)
{
// use default case
}
else
if
(
argc
==
4
)
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
}
else
if
(
argc
==
11
)
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
M
=
std
::
stoi
(
argv
[
4
]);
N
=
std
::
stoi
(
argv
[
5
]);
K
=
std
::
stoi
(
argv
[
6
]);
StrideA
=
std
::
stoi
(
argv
[
7
]);
StrideB
=
std
::
stoi
(
argv
[
8
]);
StrideD
=
std
::
stoi
(
argv
[
9
]);
StrideE
=
std
::
stoi
(
argv
[
10
]);
}
else
{
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg3: time kernel (0=no, 1=yes)
\n
"
);
printf
(
"arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE
\n
"
);
exit
(
0
);
}
auto
f_host_tensor_descriptor
=
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
auto
layout
)
{
using
namespace
ck
::
literals
;
if
(
std
::
is_same
<
decltype
(
layout
),
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
return
HostTensorDescriptor
({
row
,
col
},
{
stride
,
1
_uz
});
}
else
{
return
HostTensorDescriptor
({
row
,
col
},
{
1
_uz
,
stride
});
}
};
Tensor
<
A0DataType
>
a0_m_k
(
f_host_tensor_descriptor
(
M
,
K
,
StrideA
,
A0Layout
{}));
Tensor
<
B0DataType
>
b0_k_n
(
f_host_tensor_descriptor
(
K
,
N
,
StrideB
,
B0Layout
{}));
Tensor
<
B1DataType
>
b1_k_n
(
f_host_tensor_descriptor
(
K
,
N
,
0
,
B1Layout
{}));
Tensor
<
D0DataType
>
d_m_n
(
f_host_tensor_descriptor
(
M
,
N
,
StrideD
,
D0Layout
{}));
Tensor
<
EDataType
>
e_m_n_host_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideE
,
ELayout
{}));
Tensor
<
EDataType
>
e_m_n_device_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideE
,
ELayout
{}));
std
::
cout
<<
"a0_m_k: "
<<
a0_m_k
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b0_k_n: "
<<
b0_k_n
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b1_k_n: "
<<
b1_k_n
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"d_m_n: "
<<
d_m_n
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"e_m_n: "
<<
e_m_n_host_result
.
mDesc
<<
std
::
endl
;
switch
(
init_method
)
{
case
0
:
break
;
case
1
:
a0_m_k
.
GenerateTensorValue
(
GeneratorTensor_2
<
A0DataType
>
{
-
5
,
5
});
b0_k_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
B0DataType
>
{
-
5
,
5
});
b1_k_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
B1DataType
>
{
0
,
5
});
d_m_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
D0DataType
>
{
-
5
,
5
});
break
;
default:
a0_m_k
.
GenerateTensorValue
(
GeneratorTensor_3
<
A0DataType
>
{
0.0
,
1.0
});
b0_k_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
B0DataType
>
{
-
5
,
5
});
b1_k_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
B1DataType
>
{
0
,
5
});
d_m_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
D0DataType
>
{
-
0.5
,
0.5
});
}
DeviceMem
a0_device_buf
(
sizeof
(
A0DataType
)
*
a0_m_k
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
b0_device_buf
(
sizeof
(
B0DataType
)
*
b0_k_n
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
b1_device_buf
(
sizeof
(
B1DataType
)
*
b1_k_n
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
d_device_buf
(
sizeof
(
D0DataType
)
*
d_m_n
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
e_device_buf
(
sizeof
(
EDataType
)
*
e_m_n_device_result
.
mDesc
.
GetElementSpaceSize
());
a0_device_buf
.
ToDevice
(
a0_m_k
.
mData
.
data
());
b0_device_buf
.
ToDevice
(
b0_k_n
.
mData
.
data
());
b1_device_buf
.
ToDevice
(
b1_k_n
.
mData
.
data
());
d_device_buf
.
ToDevice
(
d_m_n
.
mData
.
data
());
e_device_buf
.
ToDevice
(
e_m_n_device_result
.
mData
.
data
());
auto
a_element_op
=
AElementOp
{};
auto
b_element_op
=
BElementOp
{};
auto
cde_element_op
=
CDEElementOp
{};
constexpr
ck
::
index_t
NumATensor
=
1
;
constexpr
ck
::
index_t
NumBTensor
=
1
;
constexpr
ck
::
index_t
NumDTensor
=
2
;
// do GEMM
auto
device_op
=
DeviceOpInstance
{};
auto
invoker
=
device_op
.
MakeInvoker
();
auto
argument
=
device_op
.
MakeArgument
(
std
::
array
<
const
void
*
,
NumATensor
>
{
a0_device_buf
.
GetDeviceBuffer
()},
std
::
array
<
const
void
*
,
NumBTensor
>
{
b0_device_buf
.
GetDeviceBuffer
()},
std
::
array
<
const
void
*
,
NumDTensor
>
{
b1_device_buf
.
GetDeviceBuffer
(),
d_device_buf
.
GetDeviceBuffer
()},
e_device_buf
.
GetDeviceBuffer
(),
M
,
N
,
K
,
std
::
array
<
ck
::
index_t
,
NumATensor
>
{
StrideA
},
std
::
array
<
ck
::
index_t
,
NumBTensor
>
{
StrideB
},
std
::
array
<
ck
::
index_t
,
NumDTensor
>
{
0
,
StrideD
},
StrideE
,
a_element_op
,
b_element_op
,
cde_element_op
);
if
(
!
device_op
.
IsSupportedArgument
(
argument
))
{
throw
std
::
runtime_error
(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem"
);
}
float
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
time_kernel
});
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
M
*
N
*
K
;
std
::
size_t
num_btype
=
sizeof
(
A0DataType
)
*
M
*
K
+
sizeof
(
B0DataType
)
*
K
*
N
+
sizeof
(
EDataType
)
*
M
*
N
;
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s"
<<
std
::
endl
;
e_device_buf
.
FromDevice
(
e_m_n_device_result
.
mData
.
data
());
if
(
do_verification
)
{
Tensor
<
CShuffleDataType
>
c_m_n
({
M
,
N
});
Tensor
<
A0DataType
>
a_m_k
({
M
,
K
});
Tensor
<
B1DataType
>
b_k_n
(
f_host_tensor_descriptor
(
K
,
N
,
StrideB
,
B0Layout
{}));
#if 0
for(int n = 0; n < N; ++n)
{
for(int k = 0; k < K; ++k)
{
b_element_op(b_k_n(k, n), b0_k_n(k, n), b1_k_n(k, n));
}
}
#endif
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
A0DataType
,
B0DataType
,
CShuffleDataType
,
AccDataType
,
PassThrough
,
PassThrough
,
PassThrough
>
;
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
auto
ref_argument
=
ref_gemm
.
MakeArgument
(
a0_m_k
,
b0_k_n
,
c_m_n
,
PassThrough
{},
PassThrough
{},
PassThrough
{});
ref_invoker
.
Run
(
ref_argument
);
for
(
int
m
=
0
;
m
<
M
;
++
m
)
{
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
cde_element_op
(
e_m_n_host_result
(
m
,
n
),
c_m_n
(
m
,
n
),
b1_k_n
(
0
,
n
),
d_m_n
(
m
,
n
));
}
}
e_device_buf
.
FromDevice
(
e_m_n_device_result
.
mData
.
data
());
return
ck
::
utils
::
check_err
(
e_m_n_device_result
,
e_m_n_host_result
)
?
0
:
1
;
}
return
0
;
}
include/ck/host_utility/flush_cache.hpp
0 → 100644
View file @
f0759faf
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <hip/hip_runtime.h>
#include <set>
#include "ck/ck.hpp"
#include "ck/stream_config.hpp"
#include "ck/host_utility/hip_check_error.hpp"
#include "ck/utility/flush_icache.hpp"
namespace
ck
{
namespace
utility
{
template
<
typename
Argument
>
struct
RotatingMemWrapper
{
using
ADataType
=
decltype
(
Argument
::
p_a_grid
);
using
BDataType
=
decltype
(
Argument
::
p_b_grid
);
RotatingMemWrapper
()
=
delete
;
RotatingMemWrapper
(
Argument
&
arg_
,
std
::
size_t
rotating_count_
,
std
::
size_t
size_a_
,
std
::
size_t
size_b_
)
:
arg
(
arg_
),
rotating_count
(
rotating_count_
),
size_a
(
size_a_
),
size_b
(
size_b_
)
{
p_a_grids
.
push_back
(
arg
.
p_a_grid
);
p_b_grids
.
push_back
(
arg
.
p_b_grid
);
for
(
size_t
i
=
1
;
i
<
rotating_count
;
i
++
)
{
{
void
*
pADeviceBuf
;
hip_check_error
(
hipMalloc
(
static_cast
<
void
**>
(
&
pADeviceBuf
),
size_a_
));
hip_check_error
(
hipMemcpy
(
static_cast
<
void
*>
(
pADeviceBuf
),
const_cast
<
void
*>
(
p_a_grids
[
0
]),
size_a_
,
hipMemcpyDeviceToDevice
));
p_a_grids
.
push_back
(
pADeviceBuf
);
}
{
void
*
pBDeviceBuf
;
hip_check_error
(
hipMalloc
(
static_cast
<
void
**>
(
&
pBDeviceBuf
),
size_b_
));
hip_check_error
(
hipMemcpy
(
static_cast
<
void
*>
(
pBDeviceBuf
),
const_cast
<
void
*>
(
p_b_grids
[
0
]),
size_b_
,
hipMemcpyDeviceToDevice
));
p_b_grids
.
push_back
(
pBDeviceBuf
);
}
}
}
void
Next
()
{
if
(
rotating_count
>
1
)
{
std
::
size_t
idx
=
iter
++
%
rotating_count
;
arg
.
p_a_grid
=
reinterpret_cast
<
ADataType
>
(
p_a_grids
[
idx
]);
arg
.
p_b_grid
=
reinterpret_cast
<
BDataType
>
(
p_b_grids
[
idx
]);
}
}
void
Print
()
{
std
::
cout
<<
"RotatingMemWrapper: { size_a: "
<<
size_a
<<
", size_b: "
<<
size_b
<<
", rotating_count: "
<<
rotating_count
<<
"}"
<<
std
::
endl
;
}
~
RotatingMemWrapper
()
{
if
(
rotating_count
>
1
)
{
// restore ptr
arg
.
p_a_grid
=
reinterpret_cast
<
ADataType
>
(
p_a_grids
[
0
]);
arg
.
p_b_grid
=
reinterpret_cast
<
BDataType
>
(
p_b_grids
[
0
]);
// free device mem
for
(
size_t
i
=
1
;
i
<
rotating_count
;
i
++
)
{
hip_check_error
(
hipFree
(
const_cast
<
void
*>
(
p_a_grids
[
i
])));
hip_check_error
(
hipFree
(
const_cast
<
void
*>
(
p_b_grids
[
i
])));
}
}
}
private:
Argument
&
arg
;
std
::
size_t
iter
=
0
;
std
::
size_t
rotating_count
=
1
;
std
::
size_t
size_a
=
0
;
std
::
size_t
size_b
=
0
;
std
::
vector
<
const
void
*>
p_a_grids
;
std
::
vector
<
const
void
*>
p_b_grids
;
};
inline
void
flush_icache
()
{
hipDeviceProp_t
deviceProps
;
hip_check_error
(
hipGetDeviceProperties
(
&
deviceProps
,
0
));
int32_t
gpu_block3
=
deviceProps
.
multiProcessorCount
*
60
;
ck
::
flush_icache
<<<
dim3
(
gpu_block3
),
dim3
(
64
),
0
,
nullptr
>>>
();
hip_check_error
(
hipGetLastError
());
}
// if TimePrePress == false, return time does not include preprocess's time
template
<
bool
TimePreprocess
,
typename
Args
,
typename
F
,
typename
PreProcessFunc
>
float
launch_and_time_kernel_with_preprocess
(
const
StreamConfig
&
stream_config
,
PreProcessFunc
preprocess
,
F
kernel
,
dim3
grid_dim
,
dim3
block_dim
,
std
::
size_t
lds_byte
,
Args
&
args
)
{
#if CK_TIME_KERNEL
#define MEDIAN 1
if
(
stream_config
.
time_kernel_
)
{
#if DEBUG_LOG
printf
(
"%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d}
\n
"
,
__func__
,
grid_dim
.
x
,
grid_dim
.
y
,
grid_dim
.
z
,
block_dim
.
x
,
block_dim
.
y
,
block_dim
.
z
);
printf
(
"Warm up %d times
\n
"
,
stream_config
.
cold_niters_
);
#endif
// warm up
for
(
int
i
=
0
;
i
<
stream_config
.
cold_niters_
;
++
i
)
{
kernel
<<<
grid_dim
,
block_dim
,
lds_byte
,
stream_config
.
stream_id_
>>>
(
args
);
hip_check_error
(
hipGetLastError
());
}
const
int
nrepeat
=
stream_config
.
nrepeat_
;
if
(
nrepeat
==
0
)
{
return
0.0
;
}
#if DEBUG_LOG
printf
(
"Start running %d times...
\n
"
,
nrepeat
);
#endif
#if MEDIAN
std
::
set
<
float
>
times
;
#else
float
total_time
=
0
;
#endif
for
(
int
i
=
0
;
i
<
nrepeat
;
++
i
)
{
if
constexpr
(
!
TimePreprocess
)
{
preprocess
();
}
hipEvent_t
start
,
stop
;
hip_check_error
(
hipEventCreate
(
&
start
));
hip_check_error
(
hipEventCreate
(
&
stop
));
hip_check_error
(
hipDeviceSynchronize
());
hip_check_error
(
hipEventRecord
(
start
,
stream_config
.
stream_id_
));
// calculate preprocess time
if
constexpr
(
TimePreprocess
)
{
preprocess
();
}
// run real kernel
kernel
<<<
grid_dim
,
block_dim
,
lds_byte
,
stream_config
.
stream_id_
>>>
(
args
);
hip_check_error
(
hipGetLastError
());
// end real kernel
hip_check_error
(
hipEventRecord
(
stop
,
stream_config
.
stream_id_
));
hip_check_error
(
hipEventSynchronize
(
stop
));
float
cur_time
=
0
;
hip_check_error
(
hipEventElapsedTime
(
&
cur_time
,
start
,
stop
));
#if MEDIAN
times
.
insert
(
cur_time
);
#else
total_time
+=
cur_time
;
#endif
#if DEBUG_LOG
std
::
cout
<<
"i: "
<<
i
<<
" cur_time: "
<<
cur_time
<<
std
::
endl
;
printf
(
"args.p_a_grid: %p, args.p_b_grid:%p
\n
"
,
static_cast
<
const
void
*>
(
args
.
p_a_grid
),
static_cast
<
const
void
*>
(
args
.
p_b_grid
));
#endif
}
#if MEDIAN
auto
mid
=
times
.
begin
();
std
::
advance
(
mid
,
(
nrepeat
-
1
)
/
2
);
if
(
nrepeat
%
2
==
1
)
{
return
*
mid
;
}
else
{
auto
mid_next
=
mid
;
std
::
advance
(
mid_next
,
1
);
return
(
*
mid
+
*
mid_next
)
/
2
;
}
#else
return
total_time
/
nrepeat
;
#endif
}
else
{
preprocess
();
kernel
<<<
grid_dim
,
block_dim
,
lds_byte
,
stream_config
.
stream_id_
>>>
(
args
);
hip_check_error
(
hipGetLastError
());
return
0
;
}
#else
kernel
<<<
grid_dim
,
block_dim
,
lds_byte
,
stream_config
.
stream_id_
>>>
(
args
);
hip_check_error
(
hipGetLastError
());
return
0
;
#endif
}
}
// namespace utility
}
// namespace ck
include/ck/stream_config.hpp
View file @
f0759faf
...
...
@@ -13,4 +13,7 @@ struct StreamConfig
int
log_level_
=
0
;
int
cold_niters_
=
5
;
int
nrepeat_
=
50
;
bool
flush_cache
=
false
;
int
rotating_count
=
1
;
};
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp
View file @
f0759faf
...
...
@@ -140,8 +140,10 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
using
Base
::
AMmaKStride
;
using
Base
::
BMmaKStride
;
static
constexpr
index_t
WgpPerCU
=
(
4
*
warpSize
/
BlockSize
)
>=
1
?
4
*
warpSize
/
BlockSize
:
1
;
static
constexpr
index_t
FullMemBandPrefetchStages
=
math
::
integer_divide_ceil
(
32768
/
(
4
*
warpSize
/
BlockSize
)
,
32768
/
WgpPerCU
,
(
MPerBlock
*
sizeof
(
ADataType
)
+
NPerBlock
*
sizeof
(
BDataType
))
*
KPerBlock
);
static
constexpr
index_t
PrefetchStages
=
FullMemBandPrefetchStages
>=
2
...
...
@@ -631,8 +633,10 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
static
constexpr
index_t
KPerInnerLoop
=
math
::
max
(
KPerThread
/
NumMacClusters
,
KPack
);
static
constexpr
index_t
KRepeat
=
KPerThread
/
KPerInnerLoop
;
static
constexpr
index_t
WgpPerCU
=
(
4
*
warpSize
/
BlockSize
)
>=
1
?
4
*
warpSize
/
BlockSize
:
1
;
static
constexpr
index_t
FullMemBandPrefetchStages
=
math
::
integer_divide_ceil
(
32768
/
(
4
*
warpSize
/
BlockSize
)
,
32768
/
WgpPerCU
,
(
MPerBlock
*
sizeof
(
ADataType
)
+
NPerBlock
*
sizeof
(
BDataType
))
*
KPerBlock
);
static
constexpr
index_t
PrefetchStages
=
FullMemBandPrefetchStages
>=
2
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp
View file @
f0759faf
...
...
@@ -184,19 +184,22 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
constexpr
auto
ds_read_b_issue_cycle
=
HotLoopInstList
::
B_LDS_Read_Width
*
sizeof
(
BDataType
)
==
16
?
8
:
4
;
constexpr
auto
ds_read_a_mfma_rate
=
(
mfma_cycle
-
8
+
ds_read_a_issue_cycle
-
1
)
/
ds_read_a_issue_cycle
;
(
mfma_cycle
-
4
+
2
*
ds_read_a_issue_cycle
-
1
)
/
(
2
*
ds_read_a_issue_cycle
)
;
constexpr
auto
ds_read_b_mfma_rate
=
(
mfma_cycle
-
8
+
ds_read_b_issue_cycle
-
1
)
/
ds_read_b_issue_cycle
;
(
mfma_cycle
-
4
+
2
*
ds_read_b_issue_cycle
-
1
)
/
(
2
*
ds_read_b_issue_cycle
);
constexpr
auto
num_dsread_a_mfma
=
(
num_ds_read_inst_a
+
ds_read_a_mfma_rate
-
1
)
/
ds_read_a_mfma_rate
;
constexpr
auto
num_dsread_b_mfma
=
(
num_ds_read_inst_b
+
ds_read_b_mfma_rate
-
1
)
/
ds_read_b_mfma_rate
;
// stage 1
// Separate this part?
constexpr
auto
num_mfma_per_ds_read
=
sizeof
(
ComputeDataType
)
/
sizeof
(
ADataType
)
>
sizeof
(
ComputeDataType
)
/
sizeof
(
BDataType
)
?
sizeof
(
ComputeDataType
)
/
sizeof
(
ADataType
)
:
sizeof
(
ComputeDataType
)
/
sizeof
(
BDataType
);
constexpr
auto
num_mfma_stage1
=
num_mfma_inst
-
num_mfma_per_ds_read
*
(
num_ds_read_inst_a
/
ds_read_a_mfma_rate
+
num_ds_read_inst_b
/
ds_read_b_mfma_rate
);
// constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) >
// sizeof(ComputeDataType) / sizeof(BDataType)
// ? sizeof(ComputeDataType) / sizeof(ADataType)
// : sizeof(ComputeDataType) / sizeof(BDataType);
constexpr
auto
num_mfma_stage1
=
num_mfma_inst
-
(
num_dsread_a_mfma
+
num_dsread_b_mfma
);
constexpr
auto
num_mfma_per_issue
=
num_mfma_stage1
/
(
num_buffer_load_inst_a
+
num_buffer_load_inst_b
);
constexpr
auto
num_dswrite_per_issue_a
=
num_ds_write_inst_a
/
num_buffer_load_inst_a
;
...
...
@@ -226,16 +229,36 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
});
// stage 2
static_for
<
0
,
num_ds_read_inst_a
/
ds_read_a_mfma_rate
,
1
>
{}([
&
](
auto
i
)
{
ignore
=
i
;
__builtin_amdgcn_sched_group_barrier
(
0x100
,
ds_read_a_mfma_rate
,
0
);
// DS read
__builtin_amdgcn_sched_group_barrier
(
0x008
,
num_mfma_per_ds_read
,
0
);
// MFMA
static_for
<
0
,
num_dsread_a_mfma
,
1
>
{}([
&
](
auto
i
)
{
if
constexpr
((
num_ds_read_inst_a
-
(
i
+
1
)
*
ds_read_a_mfma_rate
)
>=
ds_read_a_mfma_rate
)
{
__builtin_amdgcn_sched_group_barrier
(
0x100
,
ds_read_a_mfma_rate
,
0
);
// DS read
}
else
{
__builtin_amdgcn_sched_group_barrier
(
0x100
,
num_ds_read_inst_a
-
(
num_dsread_a_mfma
-
1
)
*
ds_read_a_mfma_rate
,
0
);
// DS read
}
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
});
static_for
<
0
,
num_ds_read_inst_b
/
ds_read_b_mfma_rate
,
1
>
{}([
&
](
auto
i
)
{
ignore
=
i
;
__builtin_amdgcn_sched_group_barrier
(
0x100
,
ds_read_b_mfma_rate
,
0
);
// DS read
__builtin_amdgcn_sched_group_barrier
(
0x008
,
num_mfma_per_ds_read
,
0
);
// MFMA
static_for
<
0
,
num_dsread_b_mfma
,
1
>
{}([
&
](
auto
i
)
{
if
constexpr
((
num_ds_read_inst_b
-
(
i
+
1
)
*
ds_read_b_mfma_rate
)
>=
ds_read_b_mfma_rate
)
{
__builtin_amdgcn_sched_group_barrier
(
0x100
,
ds_read_b_mfma_rate
,
0
);
// DS read
}
else
{
__builtin_amdgcn_sched_group_barrier
(
0x100
,
num_ds_read_inst_b
-
(
num_dsread_b_mfma
-
1
)
*
ds_read_b_mfma_rate
,
0
);
// DS read
}
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
});
}
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp
View file @
f0759faf
...
...
@@ -194,9 +194,9 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
constexpr
auto
ds_read_b_issue_cycle
=
HotLoopInstList
::
B_LDS_Read_Width
*
sizeof
(
BDataType
)
==
16
?
8
:
4
;
constexpr
auto
ds_read_a_mfma_rate
=
(
mfma_cycle
-
8
+
ds_read_a_issue_cycle
-
1
)
/
(
2
*
ds_read_a_issue_cycle
);
(
mfma_cycle
-
4
+
2
*
ds_read_a_issue_cycle
-
1
)
/
(
2
*
ds_read_a_issue_cycle
);
constexpr
auto
ds_read_b_mfma_rate
=
(
mfma_cycle
-
8
+
ds_read_b_issue_cycle
-
1
)
/
(
2
*
ds_read_b_issue_cycle
);
(
mfma_cycle
-
4
+
2
*
ds_read_b_issue_cycle
-
1
)
/
(
2
*
ds_read_b_issue_cycle
);
constexpr
auto
num_dsread_stage1_a
=
num_ds_read_inst_a
/
KRepeat
*
(
KRepeat
-
1
);
constexpr
auto
num_dsread_stage1_b
=
num_ds_read_inst_b
/
KRepeat
*
(
KRepeat
-
1
);
...
...
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp
View file @
f0759faf
...
...
@@ -41,7 +41,8 @@ template <typename ThreadGroup,
index_t
SrcScalarPerVector
,
index_t
DstScalarPerVector
,
typename
ThreadTransferSrcResetCoordinateAfterRunFlags
,
typename
ThreadTransferDstResetCoordinateAfterRunFlags
>
typename
ThreadTransferDstResetCoordinateAfterRunFlags
,
index_t
NumThreadScratch
=
1
>
struct
ThreadGroupTensorSliceTransfer_v7r2
{
static
constexpr
index_t
nDim
=
...
...
@@ -100,7 +101,7 @@ struct ThreadGroupTensorSliceTransfer_v7r2
ThreadGroup
::
GetThreadId
()
<
thread_cluster_desc_
.
GetElementSize
())
{
const
auto
thread_cluster_idx
=
thread_cluster_desc_
.
CalculateBottomIndex
(
make_multi_index
(
get_thread_local_1d_i
d
()));
make_multi_index
(
ThreadGroup
::
GetThreadI
d
()));
const
auto
thread_data_idx_begin
=
thread_cluster_idx
*
thread_slice_lengths
;
...
...
@@ -117,29 +118,33 @@ struct ThreadGroupTensorSliceTransfer_v7r2
}
}
template
<
typename
SrcBuffers
>
__device__
void
RunRead
(
const
SrcDescs
&
src_descs
,
const
SrcBuffers
&
src_bufs
)
template
<
typename
SrcBuffers
,
index_t
ThreadScratchId
=
0
>
__device__
void
RunRead
(
const
SrcDescs
&
src_descs
,
const
SrcBuffers
&
src_bufs
,
Number
<
ThreadScratchId
>
thread_scratch_id
=
Number
<
ThreadScratchId
>
{})
{
if
(
ThreadGroup
::
GetNumOfThread
()
==
thread_cluster_desc_
.
GetElementSize
()
or
ThreadGroup
::
GetThreadId
()
<
thread_cluster_desc_
.
GetElementSize
())
{
threadwise_transfer_
.
RunRead
(
src_descs
,
src_bufs
);
threadwise_transfer_
.
RunRead
(
src_descs
,
src_bufs
,
thread_scratch_id
);
}
}
template
<
typename
T
>
using
is_tuple
=
decltype
(
std
::
declval
<
T
&>
().
IsTuple
());
template
<
typename
DstBuffers
>
__device__
void
RunWrite
(
const
DstDescs
&
dst_descs
,
DstBuffers
dst_bufs
)
template
<
typename
DstBuffers
,
index_t
ThreadScratchId
=
0
>
__device__
void
RunWrite
(
const
DstDescs
&
dst_descs
,
DstBuffers
dst_bufs
,
Number
<
ThreadScratchId
>
thread_scratch_id
=
Number
<
ThreadScratchId
>
{})
{
if
(
ThreadGroup
::
GetNumOfThread
()
==
thread_cluster_desc_
.
GetElementSize
()
or
ThreadGroup
::
GetThreadId
()
<
thread_cluster_desc_
.
GetElementSize
())
{
if
constexpr
(
is_detected
<
is_tuple
,
decltype
(
dst_bufs
)
>::
value
)
threadwise_transfer_
.
RunWrite
(
dst_descs
,
dst_bufs
);
threadwise_transfer_
.
RunWrite
(
dst_descs
,
dst_bufs
,
thread_scratch_id
);
else
threadwise_transfer_
.
RunWrite
(
dst_descs
,
tie
(
dst_bufs
));
threadwise_transfer_
.
RunWrite
(
dst_descs
,
tie
(
dst_bufs
)
,
thread_scratch_id
);
}
}
...
...
@@ -206,7 +211,8 @@ struct ThreadGroupTensorSliceTransfer_v7r2
SrcScalarPerVector
,
DstScalarPerVector
,
ThreadTransferSrcResetCoordinateAfterRunFlags
,
ThreadTransferDstResetCoordinateAfterRunFlags
>
;
ThreadTransferDstResetCoordinateAfterRunFlags
,
NumThreadScratch
>
;
ThreadwiseTransfer
threadwise_transfer_
;
};
...
...
include/ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp
0 → 100644
View file @
f0759faf
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <array>
#include <iostream>
#include <vector>
#include <sstream>
#include "device_grouped_gemm.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
///
/// @brief Structure representing single GEMM problem arguments.
///
/// The pointer to the vector of those structures is passed to the GroupedGEMM entry
/// point kernel.
///
/// @tparam NumDTensor The number of D input tensors.
///
template
<
index_t
NumDTensor
=
0
>
struct
GroupedGemmTileLoopKernelArguments
{
__host__
__device__
GroupedGemmTileLoopKernelArguments
(
const
void
*
p_a_grid_
,
const
void
*
p_b_grid_
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds_grid_
,
void
*
p_e_grid_
,
index_t
M_
,
index_t
N_
,
index_t
K_
,
index_t
StrideA_
,
index_t
StrideB_
,
std
::
array
<
index_t
,
NumDTensor
>
StrideDs_
,
index_t
StrideE_
)
:
p_a_grid
{
p_a_grid_
},
p_b_grid
{
p_b_grid_
},
p_ds_grid
{
p_ds_grid_
},
p_e_grid
{
p_e_grid_
},
M
{
M_
},
N
{
N_
},
K
{
K_
},
StrideA
{
StrideA_
},
StrideB
{
StrideB_
},
StrideDs
{
StrideDs_
},
StrideE
{
StrideE_
}
{
}
const
void
*
p_a_grid
;
const
void
*
p_b_grid
;
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds_grid
;
void
*
p_e_grid
;
index_t
M
;
index_t
N
;
index_t
K
;
index_t
StrideA
;
index_t
StrideB
;
std
::
array
<
index_t
,
NumDTensor
>
StrideDs
;
index_t
StrideE
;
void
Print
()
const
{
std
::
stringstream
str
;
for
(
auto
sd
:
StrideDs
)
str
<<
sd
<<
","
;
std
::
cout
<<
"arg {"
<<
"M:"
<<
M
<<
", "
<<
"N:"
<<
N
<<
", "
<<
"K:"
<<
K
<<
", "
<<
"SA:"
<<
StrideA
<<
", "
<<
"SB:"
<<
StrideB
<<
", "
<<
"SE:"
<<
StrideE
<<
", "
<<
"SDs: {"
<<
str
.
str
()
<<
"}"
<<
"}"
<<
std
::
endl
;
}
};
template
<
typename
ALayout
,
typename
BLayout
,
typename
DsLayout
,
typename
ELayout
,
typename
ADataType
,
typename
BDataType
,
typename
DsDataType
,
typename
EDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
>
struct
DeviceGroupedGemmTileLoop
:
public
DeviceGroupedGemm
<
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
ADataType
,
BDataType
,
DsDataType
,
EDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
>
{
//----------------------------------------------------------------------------------------------
/// @brief Sets the device kernel arguments pointer.
///
/// @param p_arg The pointer to the Argument we're going to update.
/// @param[in] p_dev_kernel_args The pointer to the device memory which contains kernel
/// arguments.
///
virtual
void
SetDeviceKernelArgs
(
BaseArgument
*
p_arg
,
void
*
p_dev_kernel_args
)
const
=
0
;
//----------------------------------------------------------------------------------------------
/// @brief Gets the device kernel argument size.
///
/// @param[in] p_arg The pointer to the Device op Argument.
///
/// @return The device kernel argument size.
///
virtual
size_t
GetDeviceKernelArgSize
(
const
BaseArgument
*
p_arg
)
const
=
0
;
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp
View file @
f0759faf
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -627,7 +627,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
arg
.
a_max_read_elems_
%
ABlockTransferSrcScalarPerVector
==
0
;
const
bool
valid_a_access_dim_m
=
ABlockTransferSrcVectorDim
==
1
&&
arg
.
a_mz_consecutive_
;
const
bool
valid_a_access_dim_k
=
ABlockTransferSrcVectorDim
==
2
&&
arg
.
a_kz_consecutive_
;
const
bool
valid_a_access_dim
=
valid_a_access_dim_m
||
valid_a_access_dim_k
;
const
bool
valid_a_access_dim
=
valid_a_access_dim_m
||
valid_a_access_dim_k
||
ABlockTransferSrcScalarPerVector
==
1
;
if
(
!
(
valid_a_vector_size
&&
valid_a_access_dim
))
{
return
false
;
...
...
@@ -637,7 +638,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
arg
.
b_max_read_elems_
%
BBlockTransferSrcScalarPerVector
==
0
;
const
bool
valid_b_access_dim_n
=
BBlockTransferSrcVectorDim
==
1
&&
arg
.
b_nz_consecutive_
;
const
bool
valid_b_access_dim_k
=
BBlockTransferSrcVectorDim
==
2
&&
arg
.
b_kz_consecutive_
;
const
bool
valid_b_access_dim
=
valid_b_access_dim_n
||
valid_b_access_dim_k
;
const
bool
valid_b_access_dim
=
valid_b_access_dim_n
||
valid_b_access_dim_k
||
BBlockTransferSrcScalarPerVector
==
1
;
if
(
!
(
valid_b_vector_size
&&
valid_b_access_dim
))
{
return
false
;
...
...
@@ -648,7 +650,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
const
bool
valid_d_vector_size
=
arg
.
ds_max_read_elems_
[
i
]
%
CDEBlockTransferScalarPerVector_NPerBlock
==
0
;
// Vector read of Ds is always on N dimension.
const
bool
valid_d_access_dim
=
arg
.
ds_nz_consecutive_
[
i
];
const
bool
valid_d_access_dim
=
arg
.
ds_nz_consecutive_
[
i
]
||
CDEBlockTransferScalarPerVector_NPerBlock
==
1
;
if
(
!
(
valid_d_vector_size
&&
valid_d_access_dim
))
{
valid_ds_access
=
false
;
...
...
@@ -662,7 +665,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
const
bool
valid_e_vector_size
=
arg
.
e_max_write_elems_
%
CDEBlockTransferScalarPerVector_NPerBlock
==
0
;
// Vector write of E is always on N dimension.
const
bool
valid_e_access_dim
=
arg
.
e_nz_consecutive_
;
const
bool
valid_e_access_dim
=
arg
.
e_nz_consecutive_
||
CDEBlockTransferScalarPerVector_NPerBlock
==
1
;
if
(
!
(
valid_e_vector_size
&&
valid_e_access_dim
))
{
return
false
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp
View file @
f0759faf
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp
View file @
f0759faf
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp
0 → 100644
View file @
f0759faf
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp
View file @
f0759faf
...
...
@@ -92,15 +92,6 @@ struct Add
};
};
struct
Scales
{
template
<
typename
Y
,
typename
X0
,
typename
X1
>
__host__
__device__
constexpr
void
operator
()(
Y
&
y
,
const
X0
&
x0
,
const
X1
&
x1
)
const
{
y
=
ck
::
type_convert
<
Y
>
(
ck
::
type_convert
<
float
>
(
x0
)
*
ck
::
type_convert
<
float
>
(
x1
));
}
};
struct
Max
{
template
<
typename
Y
,
typename
X0
,
typename
X1
>
...
...
@@ -188,6 +179,16 @@ struct Multiply
y
=
ck
::
type_convert
<
bhalf_t
>
(
y_tmp
);
}
template
<
>
__host__
__device__
constexpr
void
operator
()
<
bhalf_t
>
(
bhalf_t
&
y
,
const
int8_t
&
x0
,
const
bhalf_t
&
x1
)
const
{
const
float
x1_tmp
=
ck
::
type_convert
<
float
>
(
x0
);
const
float
x2_tmp
=
ck
::
type_convert
<
float
>
(
x1
);
const
float
y_tmp
=
x1_tmp
*
x2_tmp
;
y
=
ck
::
type_convert
<
bhalf_t
>
(
y_tmp
);
}
template
<
>
__host__
__device__
constexpr
void
operator
()
<
bhalf_t
>
(
bhalf_t
&
y
,
const
float
&
x0
,
const
bhalf_t
&
x1
)
const
...
...
@@ -521,6 +522,71 @@ struct AddFastGelu
}
};
// E = MultiplyFastGelu(C + D)
struct
MultiplyFastGelu
{
template
<
typename
E
,
typename
C
,
typename
D
>
__host__
__device__
constexpr
void
operator
()(
E
&
e
,
const
C
&
c
,
const
D
&
d
)
const
;
template
<
>
__host__
__device__
constexpr
void
operator
()
<
float
,
float
,
float
>
(
float
&
e
,
const
float
&
c
,
const
float
&
d
)
const
{
const
float
x
=
c
*
d
;
FastGelu
{}.
template
operator
()
<
float
,
float
>(
e
,
x
);
}
template
<
>
__host__
__device__
constexpr
void
operator
()
<
half_t
,
half_t
,
half_t
>
(
half_t
&
e
,
const
half_t
&
c
,
const
half_t
&
d
)
const
{
const
half_t
x
=
c
*
d
;
ck
::
tensor_operation
::
element_wise
::
FastGelu
{}.
template
operator
()
<
half_t
,
half_t
>(
e
,
x
);
}
template
<
>
__host__
__device__
constexpr
void
operator
()
<
half_t
,
float
,
half_t
>
(
half_t
&
e
,
const
float
&
c
,
const
half_t
&
d
)
const
{
const
float
x0_f
=
c
*
d
;
float
x1_f
=
0
;
ck
::
tensor_operation
::
element_wise
::
FastGelu
{}.
template
operator
()
<
float
,
float
>(
x1_f
,
x0_f
);
e
=
type_convert
<
half_t
>
(
x1_f
);
}
template
<
>
__host__
__device__
constexpr
void
operator
()
<
bhalf_t
,
bhalf_t
,
bhalf_t
>
(
bhalf_t
&
e
,
const
bhalf_t
&
c
,
const
bhalf_t
&
d
)
const
{
const
float
x0_f
=
type_convert
<
float
>
(
c
)
*
type_convert
<
float
>
(
d
);
float
x1_f
=
0
;
FastGelu
{}.
template
operator
()
<
float
,
float
>(
x1_f
,
x0_f
);
e
=
type_convert
<
bhalf_t
>
(
x1_f
);
}
template
<
>
__host__
__device__
constexpr
void
operator
()
<
bhalf_t
,
float
,
bhalf_t
>
(
bhalf_t
&
e
,
const
float
&
c
,
const
bhalf_t
&
d
)
const
{
const
float
x0_f
=
c
*
type_convert
<
float
>
(
d
);
float
x1_f
=
0
;
FastGelu
{}.
template
operator
()
<
float
,
float
>(
x1_f
,
x0_f
);
e
=
type_convert
<
bhalf_t
>
(
x1_f
);
}
};
// E = Silu(C + D)
struct
AddSilu
{
...
...
include/ck/tensor_operation/gpu/element/element_wise_operation.hpp
View file @
f0759faf
...
...
@@ -221,6 +221,15 @@ struct MultiplyAdd
e
=
y
;
}
template
<
>
__host__
__device__
void
operator
()
<
bhalf_t
,
float
,
bhalf_t
,
bhalf_t
>
(
bhalf_t
&
e
,
const
float
&
c
,
const
bhalf_t
&
d0
,
const
bhalf_t
&
d1
)
const
{
const
bhalf_t
y
=
type_convert
<
bhalf_t
>
(
c
)
*
d0
+
d1
;
e
=
y
;
}
template
<
>
__host__
__device__
void
operator
()
<
float
,
float
,
half_t
,
half_t
>
(
float
&
e
,
const
float
&
c
,
const
half_t
&
d0
,
...
...
@@ -240,6 +249,26 @@ struct MultiplyAdd
}
};
struct
MultiplyAddFastGelu
{
template
<
typename
E
,
typename
C
,
typename
D0
,
typename
D1
>
__host__
__device__
constexpr
void
operator
()(
E
&
e
,
const
C
&
c
,
const
D0
&
d0
,
const
D1
&
d1
)
const
;
template
<
>
__host__
__device__
constexpr
void
operator
()
<
ck
::
bhalf_t
,
float
,
ck
::
bhalf_t
,
ck
::
bhalf_t
>
(
ck
::
bhalf_t
&
e
,
const
float
&
c
,
const
ck
::
bhalf_t
&
d0
,
const
ck
::
bhalf_t
&
d1
)
const
{
const
float
x0_f
=
c
*
ck
::
type_convert
<
float
>
(
d0
)
+
ck
::
type_convert
<
float
>
(
d1
);
float
x1_f
=
0
;
FastGelu
{}.
template
operator
()
<
float
,
float
>(
x1_f
,
x0_f
);
e
=
ck
::
type_convert
<
ck
::
bhalf_t
>
(
x1_f
);
}
};
// E = FastGelu(C + D0 + D1)
struct
AddAddFastGelu
{
...
...
@@ -499,6 +528,26 @@ struct UnaryTypeConvert<ck::bhalf_t, float>
}
};
struct
ConvInvscale
{
/// @brief Op to multiply convolution results by inverted scale factors
/// @param e Output after scaling
/// @param c Convolution result
/// @param d0 Input scale factor
/// @param d1 Weights scale factor
/// @param d2 Output scale factor
template
<
typename
E
,
typename
C
,
typename
D0
,
typename
D1
,
typename
D2
>
__host__
__device__
void
operator
()(
E
&
e
,
const
C
&
c
,
const
D0
&
d0
,
const
D1
&
d1
,
const
D2
&
d2
)
const
;
template
<
>
__host__
__device__
void
operator
()
<
f8_t
,
float
,
float
,
float
,
float
>
(
f8_t
&
e
,
const
float
&
c
,
const
float
&
d0
,
const
float
&
d1
,
const
float
&
d2
)
const
{
e
=
type_convert
<
f8_t
>
(
c
/
d0
/
d1
/
d2
);
};
};
}
// namespace element_wise
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
View file @
f0759faf
...
...
@@ -504,6 +504,16 @@ struct FastGelu
y
=
type_convert
<
half_t
>
(
y_f
);
}
template
<
>
__host__
void
operator
()
<
bhalf_t
,
float
>
(
bhalf_t
&
y
,
const
float
&
x
)
const
{
float
y_f
;
this
->
operator
()
<
float
,
float
>
(
y_f
,
x
);
y
=
type_convert
<
bhalf_t
>
(
y_f
);
}
template
<
>
__device__
void
operator
()
<
bhalf_t
,
float
>
(
bhalf_t
&
y
,
const
float
&
x
)
const
{
...
...
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
View file @
f0759faf
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -151,7 +151,7 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
{
}
__host__
static
constexpr
index_t
CalculateGridSize
(
index_t
M
,
index_t
N
)
__host__
__device__
static
constexpr
index_t
CalculateGridSize
(
index_t
M
,
index_t
N
)
{
const
auto
M0
=
math
::
integer_divide_ceil
(
M
,
MPerBlock
);
const
auto
N0
=
math
::
integer_divide_ceil
(
N
,
NPerBlock
);
...
...
@@ -275,7 +275,7 @@ struct BlockToCTileMap_Grouped_M00_N0_M01Adapt
{
}
__host__
static
constexpr
index_t
CalculateGridSize
(
index_t
M
,
index_t
N
)
__host__
__device__
static
constexpr
index_t
CalculateGridSize
(
index_t
M
,
index_t
N
)
{
const
auto
M0
=
math
::
integer_divide_ceil
(
M
,
MPerBlock
);
const
auto
N0
=
math
::
integer_divide_ceil
(
N
,
NPerBlock
);
...
...
@@ -428,7 +428,7 @@ struct BlockToCTileMap_N00_M0_N01Adapt<MPerBlock, NPerBlock, void>
{
}
__host__
static
constexpr
index_t
CalculateGridSize
(
index_t
M
,
index_t
N
)
__host__
__device__
static
constexpr
index_t
CalculateGridSize
(
index_t
M
,
index_t
N
)
{
const
auto
M0
=
math
::
integer_divide_ceil
(
M
,
MPerBlock
);
const
auto
N0
=
math
::
integer_divide_ceil
(
N
,
NPerBlock
);
...
...
@@ -900,6 +900,11 @@ struct OffsettedBlockToCTileMap
return
block_to_ctile_map_
.
CalculateGridSize
(
c_grid_desc_m_n
);
}
__host__
__device__
constexpr
index_t
CalculateGridSize
(
index_t
M
,
index_t
N
)
const
{
return
block_to_ctile_map_
.
CalculateGridSize
(
M
,
N
);
}
UnderlyingBlockToCTileMap
block_to_ctile_map_
;
index_t
block_start_
;
};
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp
View file @
f0759faf
...
...
@@ -594,11 +594,6 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
generate_tuple
([
&
](
auto
)
{
return
make_multi_index
(
0
,
m_block_data_idx_on_grid
,
0
);
},
Number
<
NumATensor
>
{});
#if 0
static_assert(ABlockTransferSrcScalarPerVector == ABlockTransferDstScalarPerVector_AK1,
"Src and Dst ScalarPerVector must be the same");
#endif
auto
a_blockwise_copy
=
ThreadGroupTensorSliceTransfer_v7r2
<
ThisThreadBlock
,
AsDataType
,
...
...
@@ -616,7 +611,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
2
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_AK1
,
uniform_sequence_gen_t
<
NumATensor
,
false
>
,
uniform_sequence_gen_t
<
NumATensor
,
AThreadTransferSrcResetCoordinateAfterRun
>
,
Sequence
<
true
>>
{
as_grid_desc_ak0_m_ak1
,
idx_as_block_begin
,
tie
(
a_block_desc_ak0_m_ak1
),
...
...
@@ -627,11 +622,6 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
generate_tuple
([
&
](
auto
)
{
return
make_multi_index
(
0
,
n_block_data_idx_on_grid
,
0
);
},
Number
<
NumBTensor
>
{});
#if 0
static_assert(BBlockTransferSrcScalarPerVector == BBlockTransferDstScalarPerVector_BK1,
"Src and Dst ScalarPerVector must be the same");
#endif
auto
b_blockwise_copy
=
ThreadGroupTensorSliceTransfer_v7r2
<
ThisThreadBlock
,
BsDataType
,
...
...
@@ -649,7 +639,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
2
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_BK1
,
uniform_sequence_gen_t
<
NumBTensor
,
false
>
,
uniform_sequence_gen_t
<
NumBTensor
,
BThreadTransferSrcResetCoordinateAfterRun
>
,
Sequence
<
true
>>
{
bs_grid_desc_bk0_n_bk1
,
idx_bs_block_begin
,
tie
(
b_block_desc_bk0_n_bk1
),
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
View file @
f0759faf
...
...
@@ -257,7 +257,70 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
e_grid_desc_m_n
);
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template
<
typename
ALayout
,
typename
BLayout
,
typename
ELayout
>
__host__
__device__
static
bool
CheckTensorTransfersValidity
(
index_t
MRaw
,
index_t
NRaw
,
index_t
KRaw
)
{
// Check if the vector dim is K1 or M|N
const
auto
A_vector_dim_size
=
ABlockTransferSrcVectorDim
==
2
?
KRaw
:
MRaw
;
const
auto
B_vector_dim_size
=
BBlockTransferSrcVectorDim
==
2
?
KRaw
:
NRaw
;
const
auto
E_vector_dim_size
=
NRaw
;
// check vector load for A tensor
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>
)
{
if
(
!
(
A_vector_dim_size
==
KRaw
&&
A_vector_dim_size
%
ABlockTransferSrcScalarPerVector
==
0
))
return
false
;
}
else
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
ColumnMajor
,
ALayout
>
)
{
if
(
!
(
A_vector_dim_size
==
MRaw
&&
A_vector_dim_size
%
ABlockTransferSrcScalarPerVector
==
0
))
return
false
;
}
else
{
return
false
;
}
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>
)
{
if
(
!
(
B_vector_dim_size
==
NRaw
&&
B_vector_dim_size
%
BBlockTransferSrcScalarPerVector
==
0
))
return
false
;
}
else
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
ColumnMajor
,
BLayout
>
)
{
if
(
!
(
B_vector_dim_size
==
KRaw
&&
B_vector_dim_size
%
BBlockTransferSrcScalarPerVector
==
0
))
return
false
;
}
else
{
return
false
;
}
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
ELayout
>
)
{
if
(
!
(
E_vector_dim_size
==
NRaw
&&
E_vector_dim_size
%
CDEShuffleBlockTransferScalarPerVector_NPerBlock
==
0
))
return
false
;
}
else
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
ColumnMajor
,
ELayout
>
)
{
if
(
!
(
E_vector_dim_size
==
NRaw
&&
CDEShuffleBlockTransferScalarPerVector_NPerBlock
==
1
))
return
false
;
}
else
{
return
false
;
}
return
true
;
}
template
<
typename
AGridDesc_M_K
,
typename
BGridDesc_N_K
,
typename
DsGridDesc_M_N
,
...
...
@@ -267,7 +330,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
const
BGridDesc_N_K
&
b_grid_desc_n_k
,
const
DsGridDesc_M_N
&
ds_grid_desc_m_n
,
const
EGridDesc_M_N
&
e_grid_desc_m_n
,
const
Block2ETileMap
&
)
[[
maybe_unused
]]
const
Block2ETileMap
&
)
{
static_assert
((
MPerBlock
%
(
MPerXdl
*
MXdlPerWave
)
==
0
)
&&
(
NPerBlock
%
(
NXdlPerWave
*
NPerXdl
))
==
0
,
...
...
@@ -285,7 +348,6 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
{
return
false
;
}
bool
valid
=
true
;
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
...
...
@@ -306,7 +368,6 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
// check gridwise gemm pipeline
const
auto
num_k_loop
=
AK
/
KPerBlock
;
if
(
!
GridwiseGemmPipe
::
IsSupported
(
num_k_loop
))
{
return
false
;
...
...
@@ -938,6 +999,63 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
e_grid_desc_mblock_mperblock_nblock_nperblock
,
block_2_etile_map
);
}
template
<
bool
HasMainKBlockLoop
,
typename
AGridDesc_MK
,
typename
BGridDesc_NK
,
typename
DsGridDesc_MN
,
typename
EGridDesc_MN
,
typename
Block2ETileMap
>
__device__
static
void
Run
(
const
void
*
__restrict__
p_a_grid_
,
const
void
*
__restrict__
p_b_grid_
,
DsGridPointer
p_ds_grid
,
void
*
__restrict__
p_e_grid_
,
void
*
__restrict__
p_shared
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CDEElementwiseOperation
&
cde_element_op
,
const
AGridDesc_MK
&
a_grid_desc_m_k
,
const
BGridDesc_NK
&
b_grid_desc_n_k
,
const
DsGridDesc_MN
&
ds_grid_desc_m_n
,
const
EGridDesc_MN
&
e_grid_desc_m_n
,
const
Block2ETileMap
&
block_2_etile_map
)
{
const
auto
p_a_grid
=
reinterpret_cast
<
const
ADataType
*>
(
p_a_grid_
);
const
auto
p_b_grid
=
reinterpret_cast
<
const
BDataType
*>
(
p_b_grid_
);
const
auto
p_e_grid
=
reinterpret_cast
<
EDataType
*>
(
p_e_grid_
);
// tensor descriptors for block/thread-wise copy
const
auto
a_grid_desc_ak0_m_ak1
=
MakeDefaultAGridDescriptor_AK0_M_AK1
(
a_grid_desc_m_k
);
const
auto
b_grid_desc_bk0_n_bk1
=
MakeDefaultBGridDescriptor_BK0_N_BK1
(
b_grid_desc_n_k
);
using
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
DsGridDesc_MN
{}))
>
;
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock
;
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
j
)
{
ds_grid_desc_mblock_mperblock_nblock_nperblock
(
j
)
=
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
ds_grid_desc_m_n
[
j
]);
});
const
auto
e_grid_desc_mblock_mperblock_nblock_nperblock
=
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
e_grid_desc_m_n
);
Run
<
HasMainKBlockLoop
>
(
p_a_grid
,
p_b_grid
,
p_ds_grid
,
p_e_grid
,
p_shared
,
a_element_op
,
b_element_op
,
cde_element_op
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
e_grid_desc_mblock_mperblock_nblock_nperblock
,
block_2_etile_map
);
}
};
}
// namespace ck
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp
View file @
f0759faf
This diff is collapsed.
Click to expand it.
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment