Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
ad00dd1f
Commit
ad00dd1f
authored
Aug 25, 2022
by
Adam Osewski
Browse files
SplitK int4 example
parent
c770444f
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
165 additions
and
36 deletions
+165
-36
example/35_splitK_gemm/CMakeLists.txt
example/35_splitK_gemm/CMakeLists.txt
+13
-0
example/35_splitK_gemm/run_splitK_gemm_example.inc
example/35_splitK_gemm/run_splitK_gemm_example.inc
+60
-36
example/35_splitK_gemm/splitK_gemm_xdl_int4.cpp
example/35_splitK_gemm/splitK_gemm_xdl_int4.cpp
+92
-0
No files found.
example/35_splitK_gemm/CMakeLists.txt
View file @
ad00dd1f
add_custom_target
(
example_splitK_gemm_xdl
)
add_example_executable
(
example_splitK_gemm_xdl_fp32 splitK_gemm_xdl_fp32.cpp
)
add_example_executable
(
example_splitK_gemm_xdl_fp32 splitK_gemm_xdl_fp32.cpp
)
add_example_executable
(
example_splitK_gemm_xdl_fp16 splitK_gemm_xdl_fp16.cpp
)
add_example_executable
(
example_splitK_gemm_xdl_fp16 splitK_gemm_xdl_fp16.cpp
)
add_example_executable
(
example_splitK_gemm_xdl_bfp16 splitK_gemm_xdl_bfp16.cpp
)
add_example_executable
(
example_splitK_gemm_xdl_bfp16 splitK_gemm_xdl_bfp16.cpp
)
add_example_executable
(
example_splitK_gemm_xdl_int8 splitK_gemm_xdl_int8.cpp
)
add_example_executable
(
example_splitK_gemm_xdl_int8 splitK_gemm_xdl_int8.cpp
)
add_dependencies
(
example_splitK_gemm_xdl
example_splitK_gemm_xdl_fp32
example_splitK_gemm_xdl_fp16
example_splitK_gemm_xdl_bfp16
example_splitK_gemm_xdl_int8
)
if
(
USE_BITINT_EXTENSION_INT4
)
add_example_executable
(
example_splitK_gemm_xdl_int4 splitK_gemm_xdl_int4.cpp
)
add_dependencies
(
example_splitK_gemm_xdl example_splitK_gemm_xdl_int4
)
endif
()
example/35_splitK_gemm/run_splitK_gemm_example.inc
View file @
ad00dd1f
...
@@ -24,6 +24,12 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con
...
@@ -24,6 +24,12 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con
{
{
using
namespace
ck
::
literals
;
using
namespace
ck
::
literals
;
#if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4)
static_assert
(
sizeof
(
ck
::
int4_t
)
==
sizeof
(
int8_t
));
static_assert
(
sizeof
(
ADataType
)
==
sizeof
(
KernelADataType
));
static_assert
(
sizeof
(
BDataType
)
==
sizeof
(
KernelBDataType
));
#endif
auto
&
[
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
,
KBatch
]
=
problem_size
;
auto
&
[
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
,
KBatch
]
=
problem_size
;
auto
f_host_tensor_descriptor
=
auto
f_host_tensor_descriptor
=
...
@@ -42,12 +48,11 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con
...
@@ -42,12 +48,11 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con
Tensor
<
ADataType
>
a_m_k
(
f_host_tensor_descriptor
(
M
,
K
,
StrideA
,
ALayout
{}));
Tensor
<
ADataType
>
a_m_k
(
f_host_tensor_descriptor
(
M
,
K
,
StrideA
,
ALayout
{}));
Tensor
<
BDataType
>
b_k_n
(
f_host_tensor_descriptor
(
K
,
N
,
StrideB
,
BLayout
{}));
Tensor
<
BDataType
>
b_k_n
(
f_host_tensor_descriptor
(
K
,
N
,
StrideB
,
BLayout
{}));
Tensor
<
CDataType
>
c_m_n_host_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
Tensor
<
CDataType
>
c_m_n_device_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
Tensor
<
CDataType
>
c_m_n_device_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
std
::
cout
<<
"a_m_k: "
<<
a_m_k
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"a_m_k: "
<<
a_m_k
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b_k_n: "
<<
b_k_n
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b_k_n: "
<<
b_k_n
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"c_m_n: "
<<
c_m_n_
host
_result
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"c_m_n: "
<<
c_m_n_
device
_result
.
mDesc
<<
std
::
endl
;
switch
(
config
.
init_method
)
switch
(
config
.
init_method
)
{
{
...
@@ -69,8 +74,16 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con
...
@@ -69,8 +74,16 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con
DeviceMem
b_k_n_device_buf
(
sizeof
(
BDataType
)
*
b_k_n
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
b_k_n_device_buf
(
sizeof
(
BDataType
)
*
b_k_n
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
c_m_n_device_buf
(
sizeof
(
CDataType
)
*
c_m_n_device_result
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
c_m_n_device_buf
(
sizeof
(
CDataType
)
*
c_m_n_device_result
.
mDesc
.
GetElementSpaceSize
());
#ifdef BUILD_INT4_EXAMPLE
const
Tensor
<
KernelADataType
>
a_m_k_converted
(
a_m_k
);
const
Tensor
<
KernelBDataType
>
b_k_n_converted
(
b_k_n
);
a_m_k_device_buf
.
ToDevice
(
a_m_k_converted
.
mData
.
data
());
b_k_n_device_buf
.
ToDevice
(
b_k_n_converted
.
mData
.
data
());
#else
a_m_k_device_buf
.
ToDevice
(
a_m_k
.
mData
.
data
());
a_m_k_device_buf
.
ToDevice
(
a_m_k
.
mData
.
data
());
b_k_n_device_buf
.
ToDevice
(
b_k_n
.
mData
.
data
());
b_k_n_device_buf
.
ToDevice
(
b_k_n
.
mData
.
data
());
#endif
c_m_n_device_buf
.
SetZero
();
c_m_n_device_buf
.
SetZero
();
auto
a_element_op
=
AElementOp
{};
auto
a_element_op
=
AElementOp
{};
...
@@ -80,19 +93,25 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con
...
@@ -80,19 +93,25 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con
// do GEMM
// do GEMM
auto
gemm
=
DeviceGemmInstance
{};
auto
gemm
=
DeviceGemmInstance
{};
auto
invoker
=
gemm
.
MakeInvoker
();
auto
invoker
=
gemm
.
MakeInvoker
();
auto
argument
=
gemm
.
MakeArgument
(
static_cast
<
ADataType
*>
(
a_m_k_device_buf
.
GetDeviceBuffer
()),
auto
argument
=
gemm
.
MakeArgument
(
static_cast
<
BDataType
*>
(
b_k_n_device_buf
.
GetDeviceBuffer
()),
#ifdef BUILD_INT4_EXAMPLE
static_cast
<
CDataType
*>
(
c_m_n_device_buf
.
GetDeviceBuffer
()),
static_cast
<
KernelADataType
*>
(
a_m_k_device_buf
.
GetDeviceBuffer
()),
M
,
static_cast
<
KernelBDataType
*>
(
b_k_n_device_buf
.
GetDeviceBuffer
()),
N
,
#else
K
,
static_cast
<
ADataType
*>
(
a_m_k_device_buf
.
GetDeviceBuffer
()),
StrideA
,
static_cast
<
BDataType
*>
(
b_k_n_device_buf
.
GetDeviceBuffer
()),
StrideB
,
#endif
StrideC
,
static_cast
<
CDataType
*>
(
c_m_n_device_buf
.
GetDeviceBuffer
()),
a_element_op
,
M
,
b_element_op
,
N
,
c_element_op
,
K
,
KBatch
);
StrideA
,
StrideB
,
StrideC
,
a_element_op
,
b_element_op
,
c_element_op
,
KBatch
);
if
(
!
gemm
.
IsSupportedArgument
(
argument
))
if
(
!
gemm
.
IsSupportedArgument
(
argument
))
{
{
...
@@ -101,23 +120,12 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con
...
@@ -101,23 +120,12 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con
return
0
;
return
0
;
}
}
float
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
config
.
time_kernel
});
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
false
});
bool
pass
=
true
;
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
M
*
N
*
K
;
std
::
size_t
num_btype
=
sizeof
(
ADataType
)
*
M
*
K
+
sizeof
(
BDataType
)
*
K
*
N
+
sizeof
(
CDataType
)
*
M
*
N
;
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
gemm
.
GetTypeString
()
<<
std
::
endl
;
c_m_n_device_buf
.
FromDevice
(
c_m_n_device_result
.
mData
.
data
());
if
(
config
.
do_verification
)
if
(
config
.
do_verification
)
{
{
c_m_n_device_buf
.
FromDevice
(
c_m_n_device_result
.
mData
.
data
());
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
BDataType
,
CDataType
,
CDataType
,
...
@@ -129,6 +137,8 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con
...
@@ -129,6 +137,8 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
Tensor
<
CDataType
>
c_m_n_host_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
auto
ref_argument
=
ref_gemm
.
MakeArgument
(
auto
ref_argument
=
ref_gemm
.
MakeArgument
(
a_m_k
,
b_k_n
,
c_m_n_host_result
,
a_element_op
,
b_element_op
,
c_element_op
);
a_m_k
,
b_k_n
,
c_m_n_host_result
,
a_element_op
,
b_element_op
,
c_element_op
);
...
@@ -136,19 +146,33 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con
...
@@ -136,19 +146,33 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con
if
(
std
::
is_same
<
CDataType
,
ck
::
half_t
>::
value
)
if
(
std
::
is_same
<
CDataType
,
ck
::
half_t
>::
value
)
{
{
return
ck
::
utils
::
check_err
(
c_m_n_device_result
.
mData
,
pass
&=
ck
::
utils
::
check_err
(
c_m_n_device_result
.
mData
,
c_m_n_host_result
.
mData
,
c_m_n_host_result
.
mData
,
"fp16 incorrect result"
,
"fp16 incorrect result"
,
3
e
-
3
,
3
e
-
3
,
1
e
-
3
);
1
e
-
3
);
}
}
else
else
{
{
return
ck
::
utils
::
check_err
(
c_m_n_device_result
.
mData
,
c_m_n_host_result
.
mData
);
pass
&=
ck
::
utils
::
check_err
(
c_m_n_device_result
.
mData
,
c_m_n_host_result
.
mData
);
}
}
}
}
return
true
;
if
(
config
.
time_kernel
)
{
float
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
config
.
time_kernel
});
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
M
*
N
*
K
;
std
::
size_t
num_btype
=
sizeof
(
ADataType
)
*
M
*
K
+
sizeof
(
BDataType
)
*
K
*
N
+
sizeof
(
CDataType
)
*
M
*
N
;
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
gemm
.
GetTypeString
()
<<
std
::
endl
;
}
return
pass
;
}
}
bool
run_splitK_gemm_example
(
int
argc
,
char
*
argv
[])
bool
run_splitK_gemm_example
(
int
argc
,
char
*
argv
[])
...
...
example/35_splitK_gemm/splitK_gemm_xdl_int4.cpp
0 → 100644
View file @
ad00dd1f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/literals.hpp"
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
ADataType
=
ck
::
int4_t
;
using
BDataType
=
ck
::
int4_t
;
using
AccDataType
=
int32_t
;
using
CDataType
=
int32_t
;
using
KernelADataType
=
int8_t
;
using
KernelBDataType
=
int8_t
;
using
ALayout
=
Row
;
using
BLayout
=
Col
;
using
CLayout
=
Row
;
using
AElementOp
=
PassThrough
;
using
BElementOp
=
PassThrough
;
using
CElementOp
=
PassThrough
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemmXdlSplitKCShuffle
// clang-format off
<
KernelADataType
,
//ADataType
KernelBDataType
,
//BDataType
CDataType
,
//EDataType
AccDataType
,
//AccDataType
ALayout
,
//ALayout
BLayout
,
//BLayout
CLayout
,
//ELayout
AElementOp
,
//AElementwiseOperation
BElementOp
,
//BElementwiseOperation
CElementOp
,
//CElementwiseOperation
GemmDefault
,
//GEMMSpecialization
256
,
// BlockSize
256
,
// MPerBlock
128
,
// NPerBlock
4
,
// KPerBlock
16
,
// K1
32
,
// MPerXdl
32
,
// NPerXdl
4
,
// MXdlPerWave
2
,
// NXdlPerWave
S
<
1
,
4
,
64
,
1
>
,
// ABlockTransfer ThreadCluster Lengths_K0_M_K1
S
<
0
,
2
,
1
,
3
>
,
// ABlockTransfer ThreadCluster ArrangeOrder
S
<
0
,
2
,
1
,
3
>
,
// ABlockTransfer SrcAccessOrder
3
,
// ABlockTransfer SrcVectorDim
16
,
// ABlockTransfer SrcScalarPerVector
16
,
// ABlockTransfer DstScalarPerVector_K1
true
,
// ABlockLdsExtraM
S
<
1
,
4
,
64
,
1
>
,
// BBlockTransfer ThreadCluster Lengths_K0_N_K1
S
<
0
,
1
,
3
,
2
>
,
// BBlockTransfer ThreadCluster ArrangeOrder
S
<
0
,
1
,
3
,
2
>
,
// BBlockTransfer SrcAccessOrder
3
,
// BBlockTransfer SrcVectorDim
16
,
// BBlockTransfer SrcScalarPerVector
16
,
// BBlockTransfer DstScalarPerVector_K1
true
,
// BBlockLdsExtraN
1
,
// CShuffleMXdlPerWavePerShuffle
1
,
// CShuffleNXdlPerWavePerShuffle
S
<
1
,
32
,
1
,
8
>
,
// CBlockTransferClusterLengths _MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
4
>
;
// CBlockTransferScalarPerVector_NWaveNPerXdl
// clang-format on
#define BUILD_INT4_EXAMPLE
#include "run_splitK_gemm_example.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_splitK_gemm_example
(
argc
,
argv
);
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment