Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
d6e053a3
Commit
d6e053a3
authored
Apr 04, 2022
by
rocking
Browse files
Part of gemm + softmax, Add gemm + reduceMax
parent
f015c776
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
283 additions
and
0 deletions
+283
-0
example/19_gemm_softmax/CMakeLists.txt
example/19_gemm_softmax/CMakeLists.txt
+1
-0
example/19_gemm_softmax/gemm_softmax_xdl_fp16.cpp
example/19_gemm_softmax/gemm_softmax_xdl_fp16.cpp
+281
-0
example/CMakeLists.txt
example/CMakeLists.txt
+1
-0
No files found.
example/19_gemm_softmax/CMakeLists.txt
0 → 100644
View file @
d6e053a3
add_example_executable
(
example_gemm_softmax_xdl_fp16 gemm_softmax_xdl_fp16.cpp
)
example/19_gemm_softmax/gemm_softmax_xdl_fp16.cpp
0 → 100644
View file @
d6e053a3
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
#include <half.hpp>
#include "config.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "host_reduce_util.hpp"
#include "device_tensor.hpp"
#include "device_gemm_xdl.hpp"
#include "device_gemm_xdl_c_shuffle.hpp"
#include "element_wise_operation.hpp"
#include "reference_gemm.hpp"
#include "gemm_specialization.hpp"
#include "device_reduce_blockwise.hpp"
#include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
ADataType
=
F16
;
using
BDataType
=
F16
;
using
CDataType
=
F16
;
using
AccDataType
=
F32
;
using
ALayout
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
BLayout
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
CLayout
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
// clang-format off
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemmXdl_C_Shuffle
<
ADataType
,
// ADataType
BDataType
,
// BDataType
CDataType
,
// CDataType
AccDataType
,
// AccDataType
CDataType
,
// CShuffleDataType
ALayout
,
// ALayout
BLayout
,
// BLayout
CLayout
,
// CLayout
PassThrough
,
// AElementwiseOperation
PassThrough
,
// BElementwiseOperation
PassThrough
,
// CElementwiseOperation
256
,
// BlockSize
256
,
// MPerBlock
128
,
// NPerBlock
32
,
// KPerBlock
8
,
// AK1
8
,
// BK1
32
,
// MPerXDL
32
,
// NPerXDL
4
,
// MXdlPerWave
2
,
// NXdlPerWave
S
<
4
,
64
,
1
>
,
// ABlockTransferThreadClusterLengths_K0_M_K1
S
<
1
,
0
,
2
>
,
// ABlockTransferThreadClusterArrangeOrder
S
<
1
,
0
,
2
>
,
// ABlockTransferSrcAccessOrder
2
,
// ABlockTransferSrcVectorDim
8
,
// ABlockTransferSrcScalarPerVector
8
,
// ABlockTransferDstScalarPerVector_K1
true
,
// ABlockLdsAddExtraM
S
<
4
,
64
,
1
>
,
// BBlockTransferThreadClusterLengths_K0_N_K1
S
<
1
,
0
,
2
>
,
// BBlockTransferThreadClusterArrangeOrder
S
<
1
,
0
,
2
>
,
// BBlockTransferSrcAccessOrder
2
,
// BBlockTransferSrcVectorDim
8
,
// BBlockTransferSrcScalarPerVector
8
,
// BBlockTransferDstScalarPerVector_K1
true
,
// BBlockLdsAddExtraN
1
,
// CShuffleMXdlPerWavePerShuffle
1
,
// CShuffleNXdlPerWavePerShuffle
S
<
1
,
1
,
32
,
1
,
1
,
8
>
,
// CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
8
>
;
// CBlockTransferScalarPerVector_NWaveNPerXdl
// clang-format on
constexpr
int
Rank
=
2
;
constexpr
int
NumReduceDim
=
1
;
constexpr
ck
::
ReduceTensorOp_t
ReduceOpId
=
ck
::
ReduceTensorOp_t
::
MAX
;
constexpr
ck
::
NanPropagation_t
NanOpt
=
ck
::
NanPropagation_t
::
PROPAGATE_NAN
;
constexpr
bool
PropagateNan
=
(
NanOpt
==
ck
::
NanPropagation_t
::
NOT_PROPAGATE_NAN
)
?
false
:
true
;
// constexpr ck::ReduceTensorIndices_t IndicesOpt = ck::ReduceTensorIndices_t::NO_INDICES;
using
ReduceOperation
=
typename
ck
::
reduce_binary_operator
<
CDataType
,
ReduceOpId
>::
opType
;
using
InElementwiseOperation
=
typename
ck
::
reduce_unary_operator
<
CDataType
,
ReduceOpId
,
true
,
true
>::
InElementwiseOperation
;
using
AccElementwiseOperation
=
typename
ck
::
reduce_unary_operator
<
CDataType
,
ReduceOpId
,
true
,
true
>::
AccElementwiseOperation
;
using
DeviceReduceInstance
=
ck
::
tensor_operation
::
device
::
DeviceReduceBlockWise
<
CDataType
,
CDataType
,
CDataType
,
Rank
,
NumReduceDim
,
ReduceOperation
,
InElementwiseOperation
,
AccElementwiseOperation
,
PropagateNan
,
false
,
256
,
4
,
64
,
1
,
1
,
0
,
1
,
1
>
;
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
PassThrough
,
PassThrough
,
PassThrough
>
;
int
main
(
int
argc
,
char
*
argv
[])
{
bool
do_verification
=
0
;
int
init_method
=
0
;
int
nrepeat
=
5
;
// GEMM shape
ck
::
index_t
M
=
3840
;
ck
::
index_t
N
=
4096
;
ck
::
index_t
K
=
4096
;
ck
::
index_t
StrideA
=
4096
;
ck
::
index_t
StrideB
=
4096
;
ck
::
index_t
StrideC
=
4096
;
const
std
::
vector
<
int
>
reduceDims
{
0
};
const
std
::
vector
<
int
>
reduceInvariantDims
{
1
};
if
(
argc
==
4
)
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
nrepeat
=
std
::
stoi
(
argv
[
3
]);
}
else
if
(
argc
==
10
)
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
nrepeat
=
std
::
stoi
(
argv
[
3
]);
M
=
std
::
stoi
(
argv
[
4
]);
N
=
std
::
stoi
(
argv
[
5
]);
K
=
std
::
stoi
(
argv
[
6
]);
StrideA
=
std
::
stoi
(
argv
[
7
]);
StrideB
=
std
::
stoi
(
argv
[
8
]);
StrideC
=
std
::
stoi
(
argv
[
9
]);
}
else
{
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg3: run kernel # of times (>1)
\n
"
);
printf
(
"arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC
\n
"
);
exit
(
0
);
}
auto
f_host_tensor_descriptor
=
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
auto
layout
)
{
if
(
std
::
is_same
<
decltype
(
layout
),
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
return
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
row
,
col
}),
std
::
vector
<
std
::
size_t
>
({
stride
,
1
}));
}
else
{
return
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
row
,
col
}),
std
::
vector
<
std
::
size_t
>
({
1
,
stride
}));
}
};
Tensor
<
ADataType
>
a_m_k
(
f_host_tensor_descriptor
(
M
,
K
,
StrideA
,
ALayout
{}));
Tensor
<
BDataType
>
b_k_n
(
f_host_tensor_descriptor
(
K
,
N
,
StrideB
,
BLayout
{}));
Tensor
<
CDataType
>
c_m_n
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
Tensor
<
int
>
c_m_n_max
(
std
::
vector
<
std
::
size_t
>
({
static_cast
<
std
::
size_t
>
(
N
)}),
std
::
vector
<
std
::
size_t
>
({
1
}));
const
auto
i_inLengths
=
ck
::
to_int_vector
(
c_m_n
.
mDesc
.
GetLengths
());
const
auto
i_inStrides
=
ck
::
to_int_vector
(
c_m_n
.
mDesc
.
GetStrides
());
const
auto
i_outLengths
=
ck
::
to_int_vector
(
c_m_n_max
.
mDesc
.
GetLengths
());
const
auto
i_outStrides
=
ck
::
to_int_vector
(
c_m_n_max
.
mDesc
.
GetStrides
());
size_t
reduce_total_length
=
c_m_n
.
mDesc
.
GetElementSize
()
/
c_m_n_max
.
mDesc
.
GetElementSize
();
std
::
cout
<<
"a_m_k: "
<<
a_m_k
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b_k_n: "
<<
b_k_n
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"c_m_n: "
<<
c_m_n
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"c_m_n_max: "
<<
c_m_n_max
.
mDesc
<<
std
::
endl
;
switch
(
init_method
)
{
case
0
:
break
;
case
1
:
a_m_k
.
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
5
,
5
});
b_k_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
BDataType
>
{
-
5
,
5
});
break
;
default:
a_m_k
.
GenerateTensorValue
(
GeneratorTensor_3
<
ADataType
>
{
0.0
,
1.0
});
b_k_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
BDataType
>
{
-
0.5
,
0.5
});
}
DeviceMem
a_m_k_device_buf
(
sizeof
(
ADataType
)
*
a_m_k
.
mDesc
.
GetElementSpace
());
DeviceMem
b_k_n_device_buf
(
sizeof
(
BDataType
)
*
b_k_n
.
mDesc
.
GetElementSpace
());
DeviceMem
c_m_n_device_buf
(
sizeof
(
CDataType
)
*
c_m_n
.
mDesc
.
GetElementSpace
());
DeviceMem
c_m_n_max_device_buf
(
sizeof
(
CDataType
)
*
c_m_n_max
.
mDesc
.
GetElementSpace
());
DeviceMem
c_m_n_max_indices_dev
(
0
);
a_m_k_device_buf
.
ToDevice
(
a_m_k
.
mData
.
data
());
b_k_n_device_buf
.
ToDevice
(
b_k_n
.
mData
.
data
());
// do GEMM
auto
gemm
=
DeviceGemmInstance
{};
auto
gemm_invoker
=
gemm
.
MakeInvoker
();
auto
gemm_argument
=
gemm
.
MakeArgument
(
static_cast
<
ADataType
*>
(
a_m_k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
BDataType
*>
(
b_k_n_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
c_m_n_device_buf
.
GetDeviceBuffer
()),
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
,
PassThrough
{},
PassThrough
{},
PassThrough
{});
if
(
!
gemm
.
IsSupportedArgument
(
gemm_argument
))
{
throw
std
::
runtime_error
(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem"
);
}
gemm_invoker
.
Run
(
gemm_argument
,
nrepeat
);
// do reduce max
auto
reduce
=
DeviceReduceInstance
{};
auto
wsSizeInBytes
=
reduce
.
GetWorkspaceSizeInBytes
(
i_inLengths
,
reduceDims
);
DeviceMem
ws_dev
(
wsSizeInBytes
);
auto
argument_ptr
=
reduce
.
MakeArgumentPointer
(
i_inLengths
,
i_inStrides
,
i_outLengths
,
i_outStrides
,
reduceDims
,
1
,
0
,
c_m_n_device_buf
.
GetDeviceBuffer
(),
c_m_n_max_device_buf
.
GetDeviceBuffer
(),
c_m_n_max_indices_dev
.
GetDeviceBuffer
(),
ws_dev
.
GetDeviceBuffer
(),
InElementwiseOperation
{
static_cast
<
int
>
(
reduce_total_length
)},
AccElementwiseOperation
{
static_cast
<
int
>
(
reduce_total_length
)});
if
(
!
reduce
.
IsSupportedArgument
(
argument_ptr
.
get
()))
{
std
::
cout
<<
"The runtime parameters seems not supported by the DeviceReduce instance, exiting!"
<<
std
::
endl
;
};
auto
invoker_ptr
=
reduce
.
MakeInvokerPointer
();
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
nrepeat
);
// TODO = do_verification
if
(
do_verification
)
;
return
0
;
}
example/CMakeLists.txt
View file @
d6e053a3
...
...
@@ -43,3 +43,4 @@ add_subdirectory(17_convnd_bwd_data_xdl)
add_subdirectory
(
15_grouped_gemm
)
add_subdirectory
(
16_gemm_reduce
)
add_subdirectory
(
18_batched_gemm_reduce
)
add_subdirectory
(
19_gemm_softmax
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment