gaoqiong / composable_kernel

Commit 0c51a35e, authored Jul 25, 2023 by aska-0096
fpAintB kernel compile pass
Parent: febd76e4
Showing 13 changed files with 558 additions and 142 deletions (+558, -142):
example/49_fpAintB_gemm/common.hpp (+89, -0)
example/49_fpAintB_gemm/fp16int8_gemm_wmma.cpp (+9, -3)
example/49_fpAintB_gemm/run_gemm_example.inc (+13, -1)
include/ck/tensor_operation/gpu/block/blockwise_fpAintB_gemm_wmma.hpp (+86, -24)
include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp (+81, -74)
include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp (+75, -24)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp (+5, -0)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp (+2, -2)
include/ck/utility/data_type.hpp (+10, -0)
library/include/ck/library/reference_tensor_operation/cpu/reference_fpAintB_gemm.hpp (+177, -0)
library/src/tensor_operation_instance/gpu/gemm/device_gemm_wmma_f16_f16_f16_km_kn_mn_instance.cpp (+4, -5)
library/src/tensor_operation_instance/gpu/gemm/device_gemm_wmma_f16_f16_f16_km_nk_mn_instance.cpp (+4, -5)
library/src/tensor_operation_instance/gpu/gemm/device_gemm_wmma_f16_f16_f16_mk_kn_mn_instance.cpp (+3, -4)
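The commit message only says "fpAintB kernel compile pass"; what the files below wire together is a weight-dequantizing GEMM: fp16 activations A multiplied by int8 weights B, with a per-output-column fp16 scale applied to B before the multiply (see the DequantB comment in device_fpAintB_gemm_wmma.hpp further down). As orientation, here is a minimal host-side sketch of that math; the function name and the plain-float stand-ins for fp16 are illustrative only, not part of the CK API.

// Host sketch of the fpAintB GEMM this commit implements:
//   DequantB(k, n) = float(B_int8(k, n)) * scale(n)
//   C(m, n)        = sum_k A(m, k) * DequantB(k, n)
#include <cstdint>
#include <vector>

void fpAintB_gemm_host_sketch(const std::vector<float>& a,     // A[M*K], row-major (stand-in for fp16)
                              const std::vector<int8_t>& b,    // B[K*N], row-major int8 weights
                              const std::vector<float>& scale, // scale[N], one factor per column of B
                              std::vector<float>& c,           // C[M*N], row-major output
                              int M, int N, int K)
{
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
        {
            float acc = 0.f;
            for(int k = 0; k < K; ++k)
            {
                // dequantize the int8 weight on the fly, then accumulate
                const float dequant_b = static_cast<float>(b[k * N + n]) * scale[n];
                acc += a[m * K + k] * dequant_b;
            }
            c[m * N + n] = acc;
        }
}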
example/49_fpAintB_gemm/common.hpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cstdlib>
#include <iostream>
#include <initializer_list>
#include <numeric>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/utility/data_type.hpp"

#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_fpAintB_gemm.hpp"

struct ProblemSize final
{
    ck::index_t M = 3840;
    ck::index_t N = 4096;
    ck::index_t K = 4096;

    ck::index_t StrideA = 4096;
    ck::index_t StrideB = 4096;
    ck::index_t StrideC = 4096;
};

struct ExecutionConfig final
{
    bool do_verification = true;
    int init_method      = 1;
    bool time_kernel     = false;
};

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

inline bool
parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfig& config)
{
    if(argc == 1)
    {
        // use default case
    }
    else if(argc == 4)
    {
        config.do_verification = std::stoi(argv[1]);
        config.init_method     = std::stoi(argv[2]);
        config.time_kernel     = std::stoi(argv[3]);
    }
    else if(argc == 10)
    {
        config.do_verification = std::stoi(argv[1]);
        config.init_method     = std::stoi(argv[2]);
        config.time_kernel     = std::stoi(argv[3]);

        problem_size.M = std::stoi(argv[4]);
        problem_size.N = std::stoi(argv[5]);
        problem_size.K = std::stoi(argv[6]);

        problem_size.StrideA = std::stoi(argv[7]);
        problem_size.StrideB = std::stoi(argv[8]);
        problem_size.StrideC = std::stoi(argv[9]);
    }
    else
    {
        std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl
                  << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)"
                  << std::endl
                  << "arg3: time kernel (0=no, 1=yes)" << std::endl
                  << "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl;
        return false;
    }

    return true;
}
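For context, a hypothetical driver showing how ProblemSize, ExecutionConfig and parse_cmd_args above fit together; the example's real main() lives in fp16int8_gemm_wmma.cpp and may differ in detail.

#include <cstdlib>

// run_gemm() is provided by run_gemm_example.inc, which the example includes after
// defining its data types; declared here only so the sketch stands on its own.
bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config);

int main(int argc, char* argv[])
{
    ProblemSize problem_size;  // defaults: M=3840, N=4096, K=4096, all strides 4096
    ExecutionConfig config;    // defaults: verify on, integer init, no kernel timing

    // Accepts 0, 3, or 9 user arguments, as documented in parse_cmd_args().
    if(!parse_cmd_args(argc, argv, problem_size, config))
    {
        return EXIT_FAILURE;
    }

    return run_gemm(problem_size, config) ? EXIT_SUCCESS : EXIT_FAILURE;
}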
example/49_fpAintB_gemm/fp16int8_gemm_wmma.cpp

@@ -37,7 +37,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceFpAintBGemm_Wmma_
         BElementOp,
         CElementOp,
         GemmDefault,
-        2,           // Prefetch stage
+        1,           // Prefetch stage
         128,         // BlockSize
         128,         // MPerBlock
         64,          // NPerBlock
@@ -67,8 +67,14 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceFpAintBGemm_Wmma_
         8>;
 // clang-format on

-using ReferenceGemmInstance = ck::tensor_operation::host::
-    ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
+using ReferenceGemmInstance = ck::tensor_operation::host::ReferencefpAintBGemm<ADataType,
+                                                                               BDataType,
+                                                                               ScaleDataType,
+                                                                               CDataType,
+                                                                               AccDataType,
+                                                                               AElementOp,
+                                                                               BElementOp,
+                                                                               CElementOp>;

 #include "run_gemm_example.inc"
example/49_fpAintB_gemm/run_gemm_example.inc

@@ -27,6 +27,8 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
     Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
     Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
+    // assume scale tensor is [1, n]
+    Tensor<ScaleDataType> scale_k_n(f_host_tensor_descriptor(K, N, 0, BLayout{}));

     switch(config.init_method)
     {
@@ -34,26 +36,32 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
     case 1:
         ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
         ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
+        ck::utils::FillUniformDistributionIntegerValue<ScaleDataType>{-5.f, 5.f}(scale_k_n);
         break;
     case 2:
         ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
         ck::utils::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
+        ck::utils::FillUniformDistribution<ScaleDataType>{-1.f, 1.f}(scale_k_n);
         break;
     case 3:
         ck::utils::FillUniformDistributionIntegerValue<ADataType>{1.f, 1.f}(a_m_k);
         ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
+        ck::utils::FillUniformDistributionIntegerValue<ScaleDataType>{-5.f, 5.f}(scale_k_n);
         break;
     case 4:
         ck::utils::FillUniformDistributionIntegerValue<ADataType>{1.f, 1.f}(a_m_k);
         ck::utils::FillUniformDistributionIntegerValue<BDataType>{1.f, 1.f}(b_k_n);
+        ck::utils::FillUniformDistributionIntegerValue<ScaleDataType>{1.f, 1.f}(scale_k_n);
         break;
     case 5:
         ck::utils::FillUniformDistributionIntegerValue<ADataType>{-2.f, 2.f}(a_m_k);
         ck::utils::FillUniformDistributionIntegerValue<BDataType>{-2.f, 2.f}(b_k_n);
+        ck::utils::FillUniformDistributionIntegerValue<ScaleDataType>{-2.f, 2.f}(scale_k_n);
         break;
     default:
         ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
         ck::utils::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
+        ck::utils::FillUniformDistribution<ScaleDataType>{-1.f, 1.f}(scale_k_n);
     }

     Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
@@ -61,6 +69,7 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
     std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
     std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
+    std::cout << "scale_k_n: " << scale_k_n.mDesc << std::endl;
     std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;

 #ifdef BUILD_INT4_EXAMPLE
@@ -77,10 +86,12 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
 #else
     DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
     DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
+    DeviceMem scale_k_n_device_buf(sizeof(ScaleDataType) * scale_k_n.mDesc.GetElementSpaceSize());
     DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());

     a_m_k_device_buf.ToDevice(a_m_k.mData.data());
     b_k_n_device_buf.ToDevice(b_k_n.mData.data());
+    scale_k_n_device_buf.ToDevice(scale_k_n.mData.data());
 #endif

     auto a_element_op = AElementOp{};
@@ -98,6 +109,7 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
 #else
         static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
         static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
+        static_cast<ScaleDataType*>(scale_k_n_device_buf.GetDeviceBuffer()),
         static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
 #endif
         M,
@@ -136,7 +148,7 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
         auto ref_invoker = ref_gemm.MakeInvoker();

         auto ref_argument = ref_gemm.MakeArgument(
-            a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
+            a_m_k, b_k_n, scale_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);

         ref_invoker.Run(ref_argument);
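The scale tensor above is logically [1, N] but is declared with shape (K, N) and a stride of 0, so every k row aliases the same N scale values (this is what the "assume scale tensor is [1, n]" comment means). A small standalone illustration of that broadcast trick, using plain float and a hypothetical indexing lambda rather than CK's host tensor:

#include <cassert>
#include <cstddef>
#include <vector>

int main()
{
    const int K = 4, N = 3;
    std::vector<float> scale = {0.5f, 1.0f, 2.0f}; // only N values are actually stored

    // Row stride 0 mimics the `0` passed as the stride argument of
    // f_host_tensor_descriptor(K, N, 0, BLayout{}) above.
    auto at = [&](int k, int n) -> float {
        const std::size_t row_stride = 0;
        return scale[static_cast<std::size_t>(k) * row_stride + static_cast<std::size_t>(n)];
    };

    for(int k = 0; k < K; ++k)
        for(int n = 0; n < N; ++n)
            assert(at(k, n) == scale[n]); // the same scale is seen for every k

    return 0;
}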
include/ck/tensor_operation/gpu/block/blockwise_fpAintB_gemm_wmma.hpp

@@ -20,7 +20,7 @@ template <index_t BlockSize,
           typename FloatAcc,
           typename ABlockDesc,
           typename BBlockDesc,
-          typename ScaleDesc,
+          typename ScaleBlockDesc,
           index_t MPerBlock,
           index_t NPerBlock,
           index_t KPerBlock,
@@ -73,8 +73,9 @@ struct Blockwise_fpAintB_GemmWMMA
     static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I5);
     static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I5);

+    // As Float DataType
     static constexpr auto wmma_gemm =
-        WmmaGemm<ADataType, BDataType, FloatAcc, MPerWMMA, NPerWMMA, KPack, TransposeC>{};
+        WmmaGemm<ADataType, ADataType, FloatAcc, MPerWMMA, NPerWMMA, KPack, TransposeC>{};

     static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA);
     static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA);
@@ -178,9 +179,10 @@ struct Blockwise_fpAintB_GemmWMMA
     }

     using Tuple6 = decltype(CalculateAThreadOriginDataIndex());

-    __host__ __device__ BlockwiseGemmWMMA(Tuple6 a_origin = CalculateAThreadOriginDataIndex(),
-                                          Tuple6 b_origin = CalculateBThreadOriginDataIndex())
-        : a_thread_copy_(a_origin), b_thread_copy_(b_origin)
+    __host__ __device__
+    Blockwise_fpAintB_GemmWMMA(Tuple6 a_origin = CalculateAThreadOriginDataIndex(),
+                               Tuple6 b_origin = CalculateBThreadOriginDataIndex())
+        : a_thread_copy_(a_origin), b_thread_copy_(b_origin), scale_thread_copy_(b_origin)
     {
         static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(),
                       "wrong! Desc should be known at compile-time");
@@ -290,8 +292,12 @@ struct Blockwise_fpAintB_GemmWMMA
     // M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma
     static constexpr ABlockDesc a_block_desc_k0_m0_m1_m2_k1;
     static constexpr BBlockDesc b_block_desc_k0_n0_n1_n2_k1;
+    static constexpr ScaleBlockDesc scale_block_desc_1_n0_n1_n2_1;

     template <typename ABlockBuffer, typename BBlockBuffer, typename ScaleBlockBuffer, typename CThreadBuffer>
     __device__ void Run(const ABlockBuffer& a_block_buf,
                         const BBlockBuffer& b_block_buf,
                         const ScaleBlockBuffer& scale_block_buf,
@@ -305,8 +311,6 @@ struct Blockwise_fpAintB_GemmWMMA
             scale_thread_desc_.GetElementSpaceSize());

         auto converted_b_thread_buf = b_thread_buf;
-        static constexpr auto dequantizer = Dequantizer<ADataType, BDataType>{};

         // basic intrinsic to determine loopover direction
         if constexpr(MRepeat < NRepeat)
         {
@@ -333,21 +337,22 @@ struct Blockwise_fpAintB_GemmWMMA
                     b_thread_buf);

                 // read weight scale
-                scale_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
+                scale_thread_copy_.Run(scale_block_desc_1_n0_n1_n2_1,
                     make_tuple(Number<k * WmmaK / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
-                    b_block_buf,
-                    b_scale_thread_desc_,
+                    scale_block_buf,
+                    scale_thread_desc_,
                     make_tuple(I0, n0, I0, I0, I0, I0),
-                    b_scale_thread_buf);
+                    scale_thread_buf);

-                // convert B from int8 to fp16
-                converted_b_thread_buf = type_convert(b_thread_buf);
-                // multiply scale
-                dequantize(converted_b_thread_buf, scale_thread_buf);
+                // convert B from int8 to fp16, multiply scale
+                static_for<0, b_thread_buf.size(), 1>{}([&](auto i) {
+                    converted_b_thread_buf(i) =
+                        scale_thread_buf[i / WmmaK] * type_convert<ADataType>(b_thread_buf[i]);
+                });

                 vector_type<ADataType, WmmaK> a_thread_vec;
-                vector_type<BDataType, WmmaK> b_thread_vec;
+                vector_type<ADataType, WmmaK> b_thread_vec;

                 static_for<0, WmmaK, 1>{}([&](auto i) {
                     a_thread_vec.template AsType<ADataType>()(i) =
@@ -358,7 +363,7 @@ struct Blockwise_fpAintB_GemmWMMA
                         (i / A_K1) % A_KRow,
                         0,
                         i % A_K1))>{}];
-                    b_thread_vec.template AsType<BDataType>()(i) =
+                    b_thread_vec.template AsType<ADataType>()(i) =
                         converted_b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                             make_tuple(i / B_K1 / B_KRow,
                                        n0,
@@ -369,7 +374,7 @@ struct Blockwise_fpAintB_GemmWMMA
                 });

                 using wmma_input_type_a = typename vector_type<ADataType, WmmaK>::type;
-                using wmma_input_type_b = typename vector_type<BDataType, WmmaK>::type;
+                using wmma_input_type_b = typename vector_type<ADataType, WmmaK>::type;

                 constexpr index_t c_offset =
                     c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
@@ -396,6 +401,20 @@ struct Blockwise_fpAintB_GemmWMMA
                     b_thread_desc_,
                     make_tuple(I0, n0, I0, I0, I0, I0),
                     b_thread_buf);

+                // read weight scale
+                scale_thread_copy_.Run(scale_block_desc_1_n0_n1_n2_1,
+                    make_tuple(Number<k * WmmaK / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
+                    scale_block_buf,
+                    scale_thread_desc_,
+                    make_tuple(I0, n0, I0, I0, I0, I0),
+                    scale_thread_buf);
+
+                // convert B from int8 to fp16, multiply scale
+                static_for<0, b_thread_buf.Size(), 1>{}([&](auto i) {
+                    converted_b_thread_buf(i) =
+                        scale_thread_buf[i / WmmaK] * type_convert<ADataType>(b_thread_buf[i]);
+                });

                 // read A
                 a_thread_copy_.Run(
                     a_block_desc_k0_m0_m1_m2_k1,
@@ -406,11 +425,11 @@ struct Blockwise_fpAintB_GemmWMMA
                     a_thread_buf);

                 vector_type<ADataType, WmmaK> a_thread_vec;
-                vector_type<BDataType, WmmaK> b_thread_vec;
+                vector_type<ADataType, WmmaK> b_thread_vec;

                 static_for<0, WmmaK, 1>{}([&](auto i) {
-                    b_thread_vec.template AsType<BDataType>()(i) =
-                        b_thread_buf[Number<b_thread_desc_.CalculateOffset(
+                    b_thread_vec.template AsType<ADataType>()(i) =
+                        converted_b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                             make_tuple(i / B_K1 / B_KRow,
                                        n0,
                                        0,
@@ -428,7 +447,7 @@ struct Blockwise_fpAintB_GemmWMMA
                 });

                 using wmma_input_type_a = typename vector_type<ADataType, WmmaK>::type;
-                using wmma_input_type_b = typename vector_type<BDataType, WmmaK>::type;
+                using wmma_input_type_b = typename vector_type<ADataType, WmmaK>::type;

                 constexpr index_t c_offset =
                     c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
@@ -472,6 +491,15 @@ struct Blockwise_fpAintB_GemmWMMA
                 Number<B_K1>{},
                 Number<1>{}));

+    static constexpr auto scale_thread_desc_ = make_naive_tensor_descriptor(
+        make_tuple(Number<WmmaK / B_K1 / B_KRow>{},
+                   Number<NRepeat>{},
+                   I1,
+                   Number<B_KRow>{},
+                   I1,
+                   Number<B_K1>{}),
+        make_tuple(I0, I1, I0, I0, I0, I0));

     // C[M, N, NumRegWMMA]
     static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
         make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, wmma_gemm.GetRegSizePerWmma()));
@@ -548,8 +576,42 @@ struct Blockwise_fpAintB_GemmWMMA
                                            TransposeC ? true : false>;
     };

+    template <bool EnableLds>
+    struct ScaleThreadCopySelector;
+
+    template <>
+    struct ScaleThreadCopySelector<true>
+    {
+        using type = ThreadwiseTensorSliceTransfer_v4<ScaleDataType,
+                                                      ScaleDataType,
+                                                      decltype(scale_block_desc_1_n0_n1_n2_1),
+                                                      decltype(scale_thread_desc_),
+                                                      Sequence<WmmaK / B_K1 / B_KRow, 1, 1, B_KRow, 1, B_K1>,
+                                                      Sequence<0, 1, 2, 3, 4, 5>,
+                                                      5,
+                                                      B_K1,
+                                                      B_K1>;
+    };
+
+    template <>
+    struct ScaleThreadCopySelector<false>
+    {
+        using type =
+            ThreadwiseTensorSliceTransfer_StaticToStatic<ScaleDataType,
+                                                         ScaleDataType,
+                                                         decltype(scale_block_desc_1_n0_n1_n2_1),
+                                                         decltype(scale_thread_desc_),
+                                                         tensor_operation::element_wise::PassThrough,
+                                                         Sequence<WmmaK / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
+                                                         Sequence<0, 1, 2, 3, 4, 5>,
+                                                         5,
+                                                         1>;
+    };
+
     typename AThreadCopySelector<AEnableLds>::type a_thread_copy_;
     typename BThreadCopySelector<BEnableLds>::type b_thread_copy_;
+    typename ScaleThreadCopySelector<BEnableLds>::type scale_thread_copy_;
 };
 } // namespace ck
include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp

@@ -22,7 +22,7 @@ namespace tensor_operation {
 namespace device {

 // 1. DequantB(K, N) = int2fp(B(K, N)) * scale(1, N)
 // 2. C(M, N) = A(M, K) * DequantB(K, N)
 template <typename ALayout,
           typename BLayout,
@@ -66,16 +66,16 @@ template <typename ALayout,
           typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
           index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
           ck::LoopScheduler LoopSched     = make_default_loop_scheduler(),
-          ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1>
+          ck::PipelineVersion PipelineVer = ck::PipelineVersion::dequant_v1>
 struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
                                                                     BLayout,
                                                                     CLayout,
                                                                     ADataType,
                                                                     BDataType,
                                                                     CDataType,
                                                                     AElementwiseOperation,
                                                                     BElementwiseOperation,
                                                                     CElementwiseOperation>
 {
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
@@ -103,7 +103,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
     static constexpr auto matrix_padder =
         MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};

     using DeviceOp = DeviceFpAintBGemm_Wmma_CShuffle;

     // Describe how data read from Global memory
@@ -183,7 +183,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
         const auto N = b_grid_desc_n_k.GetLength(I0);
         const auto K = b_grid_desc_n_k.GetLength(I1);

         // When K = 1, it might be scale tensor.
         assert(K % K1 == 0 && K != 1);

         if constexpr(BEnableLds)
         {
@@ -241,61 +241,62 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
     using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));

     // GridwiseGemm
     using GridwiseGemm =
         GridwiseFpAintBGemm_Wmma<BlockSize,
                                  ADataType,
                                  BDataType,
                                  ScaleDataType,
                                  AccDataType,
                                  CShuffleDataType,
                                  CDataType,
                                  InMemoryDataOperationEnum::Set,
                                  AGridDesc,
                                  BGridDesc,
+                                 ScaleGridDesc,
                                  CGridDesc_M_N,
                                  AElementwiseOperation,
                                  BElementwiseOperation,
                                  CElementwiseOperation,
                                  MPerBlock,
                                  NPerBlock,
                                  KPerBlock,
                                  MPerWmma,
                                  NPerWmma,
                                  K1,
                                  MRepeat,
                                  NRepeat,
                                  ABlockTransferThreadClusterLengths_K0_M_K1,
                                  ABlockTransferThreadClusterArrangeOrder,
                                  ABlockTransferSrcAccessOrder,
                                  ABlockTransferSrcVectorDim,
                                  ABlockTransferSrcScalarPerVector,
                                  ABlockTransferDstScalarPerVector_K1,
                                  false, // AThreadTransferSrcResetCoordinateAfterRun,
                                  AEnableLds,
                                  ABlockLdsAddExtraM,
                                  BBlockTransferThreadClusterLengths_K0_N_K1,
                                  BBlockTransferThreadClusterArrangeOrder,
                                  BBlockTransferSrcAccessOrder,
                                  BBlockTransferSrcVectorDim,
                                  BBlockTransferSrcScalarPerVector,
                                  BBlockTransferDstScalarPerVector_K1,
                                  false, // BThreadTransferSrcResetCoordinateAfterRun,
                                  BEnableLds,
                                  BBlockLdsAddExtraN,
                                  CShuffleMRepeatPerShuffle,
                                  CShuffleNRepeatPerShuffle,
                                  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
                                  CShuffleBlockTransferScalarPerVector_NPerBlock,
                                  NumPrefetch,
                                  LoopSched,
                                  PipelineVer>;

     // Argument
     struct Argument : public BaseArgument
     {
         Argument(const ADataType* p_a_grid,
                  const BDataType* p_b_grid,
-                 const ScaleDataType* p_scale,
+                 const ScaleDataType* p_scale_grid,
                  CDataType* p_c_grid,
                  index_t M,
                  index_t N,
@@ -310,7 +311,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
                  CElementwiseOperation c_element_op)
             : p_a_grid_{p_a_grid},
               p_b_grid_{p_b_grid},
-              p_scale_grid_{p_scale},
+              p_scale_grid_{p_scale_grid},
               p_c_grid_{p_c_grid},
               a_grid_desc_{},
               b_grid_desc_{},
@@ -327,10 +328,10 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
               NRaw_{N},
               KRaw_{K}
         {
             a_grid_desc_     = DeviceOp::MakeAGridDescriptor(M, K, StrideA);
             b_grid_desc_     = DeviceOp::MakeBGridDescriptor(K, N, StrideB);
-            scale_grid_desc_ = DeviceOp::MakeBGridDescriptor(1, N, 1);
+            scale_grid_desc_ = DeviceOp::MakeBGridDescriptor(K, N, 0);
             c_grid_desc_m_n_ = DeviceOp::MakeCGridDescriptor_M_N(M, N, StrideC);

             block_2_ctile_map_ =
                 GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
@@ -347,7 +348,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
         // private:
         const ADataType* p_a_grid_;
         const BDataType* p_b_grid_;
-        const ScaleDataType* p_b_grid_;
+        const ScaleDataType* p_scale_grid_;
         CDataType* p_c_grid_;
         AGridDesc a_grid_desc_;
         BGridDesc b_grid_desc_;
@@ -406,7 +407,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
                 CDataType,
                 remove_reference_t<DeviceOp::AGridDesc>,
                 remove_reference_t<DeviceOp::BGridDesc>,
-                remove_reference_t<DeviceOp::BGridDesc>,
+                remove_reference_t<DeviceOp::ScaleGridDesc>,
                 remove_reference_t<
                     typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
                 AElementwiseOperation,
@@ -422,9 +423,11 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
                 0,
                 arg.p_a_grid_,
                 arg.p_b_grid_,
+                arg.p_scale_grid_,
                 arg.p_c_grid_,
                 arg.a_grid_desc_,
                 arg.b_grid_desc_,
+                arg.scale_grid_desc_,
                 arg.c_grid_desc_mblock_mperblock_nblock_nperblock,
                 arg.a_element_op_,
                 arg.b_element_op_,
@@ -536,10 +539,8 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
             }
         }

-        return GridwiseGemm::CheckValidity(arg.a_grid_desc_,
-                                           arg.b_grid_desc_,
-                                           arg.c_grid_desc_m_n_,
-                                           arg.block_2_ctile_map_);
+        return GridwiseGemm::CheckValidity(
+            arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_m_n_, arg.block_2_ctile_map_);
     }

     // polymorphic
@@ -550,6 +551,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
     static auto MakeArgument(const ADataType* p_a,
                              const BDataType* p_b,
+                             const ScaleDataType* p_scale,
                              CDataType* p_c,
                              index_t M,
                              index_t N,
@@ -563,6 +565,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
     {
         return Argument{p_a,
                         p_b,
+                        p_scale,
                         p_c,
                         M,
                         N,
@@ -582,6 +585,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
     // polymorphic
     std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
                                                       const void* p_b,
+                                                      const void* p_scale,
                                                       void* p_c,
                                                       index_t M,
                                                       index_t N,
@@ -595,6 +599,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
     {
         return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                           static_cast<const BDataType*>(p_b),
+                                          static_cast<const ScaleDataType*>(p_scale),
                                           static_cast<CDataType*>(p_c),
                                           M,
                                           N,
@@ -623,8 +628,10 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
         std::map<LoopScheduler, std::string> LoopSchedToString{
             {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};

-        std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
-                                                                       {PipelineVersion::v2, "v2"}};
+        std::map<PipelineVersion, std::string> PipelineVersionToString{
+            {PipelineVersion::v1, "v1"},
+            {PipelineVersion::v2, "v2"},
+            {PipelineVersion::dequant_v1, "dequant_v1"}};

         // clang-format off
         str << "DeviceFpAintBGemm_Wmma_CShuffle"
include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp

@@ -20,9 +20,11 @@ namespace ck {
 template <typename GridwiseGemm,
           typename ADataType,
           typename BDataType,
+          typename ScaleDataType,
           typename CDataType,
           typename AGridDesc,
           typename BGridDesc,
+          typename ScaleGridDesc,
           typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
@@ -33,17 +35,19 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-        kernel_gemm_wmma(const ADataType* __restrict__ p_a_grid,
+        kernel_fpAintB_gemm_wmma(const ADataType* __restrict__ p_a_grid,
                          const BDataType* __restrict__ p_b_grid,
+                         const ScaleDataType* __restrict__ p_scale_grid,
                          CDataType* __restrict__ p_c_grid,
                          const AGridDesc a_grid_desc,
                          const BGridDesc b_grid_desc,
+                         const ScaleGridDesc scale_grid_desc,
                          const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                              c_grid_desc_mblock_mperblock_nblock_nperblock,
                          const AElementwiseOperation a_element_op,
                          const BElementwiseOperation b_element_op,
                          const CElementwiseOperation c_element_op,
                          const Block2CTileMap block_2_ctile_map)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \
     defined(__gfx1102__))
@@ -51,10 +55,12 @@ __global__ void
     GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
                                                   p_b_grid,
+                                                  p_scale_grid,
                                                   p_c_grid,
                                                   p_shared,
                                                   a_grid_desc,
                                                   b_grid_desc,
+                                                  scale_grid_desc,
                                                   c_grid_desc_mblock_mperblock_nblock_nperblock,
                                                   a_element_op,
                                                   b_element_op,
@@ -63,9 +69,11 @@ __global__ void
 #else
     ignore = p_a_grid;
     ignore = p_b_grid;
+    ignore = p_scale_grid;
     ignore = p_c_grid;
     ignore = a_grid_desc;
     ignore = b_grid_desc;
+    ignore = scale_grid_desc;
     ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
     ignore = a_element_op;
     ignore = b_element_op;
@@ -77,12 +85,14 @@ __global__ void
 template <index_t BlockSize,
           typename ADataType,
           typename BDataType,
+          typename ScaleDataType,
           typename AccDataType,
           typename CShuffleDataType,
           typename CDataType,
           InMemoryDataOperationEnum CGlobalMemoryDataOperation,
           typename AGridDesc,
           typename BGridDesc,
+          typename ScaleGridDesc,
           typename CGridDesc_M_N,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
@@ -119,7 +129,7 @@ template <index_t BlockSize,
           index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
           index_t NumGemmKPrefetchStage = 1,
           LoopScheduler LoopSched       = make_default_loop_scheduler(),
-          PipelineVersion PipelineVer   = PipelineVersion::v1>
+          PipelineVersion PipelineVer   = PipelineVersion::dequant_v1>
 struct GridwiseFpAintBGemm_Wmma
 {
     static constexpr auto I0 = Number<0>{};
@@ -140,7 +150,12 @@ struct GridwiseFpAintBGemm_Wmma
     using ThisThreadBlock = ThisThreadBlock<BlockSize>;

-    using GridwiseGemmPipe =
-        GridwiseGemmPipeline_v1_dequant<NumGemmKPrefetchStage, AEnableLds, BEnableLds>;
+    using GridwiseGemmPipe =
+        remove_cvref_t<decltype(GridwiseGemmPipeline_Selector<PipelineVer,
+                                                              NumGemmKPrefetchStage,
+                                                              LoopSched,
+                                                              AEnableLds,
+                                                              BEnableLds>())>;

     // Describe how data store to (LDS/VGPR) buffer from Global memory
     __host__ __device__ static constexpr auto MakeABlockDescriptor()
@@ -237,6 +252,38 @@ struct GridwiseFpAintBGemm_Wmma
         return b_block_desc;
     }

+    __host__ __device__ static constexpr auto MakeScaleBlockDescriptor()
+    {
+        // Scale [1, N], all K related dimension reduce to 1
+        constexpr auto scale_block_desc = [&]() {
+            if constexpr(BEnableLds)
+            {
+                // K0->N->K1 Per Block
+                constexpr auto K0PerBlock = KPerBlock / K1;
+                return make_naive_tensor_descriptor(
+                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
+                    make_tuple(I0, I1, I0));
+            }
+            else
+            {
+                constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
+                constexpr auto K0PerWmma     = WmmaK / 2 / K1;
+                // KWmma->NRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
+                return make_naive_tensor_descriptor(
+                    make_tuple(Number<KWmmaPerblock>{},
+                               Number<NRepeat>{},
+                               I1,
+                               Number<K0PerWmma>{},
+                               I1,
+                               I1,
+                               K1),
+                    make_tuple(I0, I1, I0, I0, I0, I0, I0));
+            }
+        }();
+
+        return scale_block_desc;
+    }
+
     __host__ __device__ static constexpr auto MakeABlockSliceCopyStep()
     {
         constexpr auto a_block_copy_step = [&]() {
@@ -537,9 +584,15 @@ struct GridwiseFpAintBGemm_Wmma
             BEnableLds ? math::integer_least_multiple(MakeBBlockDescriptor().GetElementSpaceSize(),
                                                       max_lds_align)
                        : 0;
+        static constexpr auto scale_block_space_size_aligned =
+            BEnableLds ? math::integer_least_multiple(
+                             MakeScaleBlockDescriptor().GetElementSpaceSize(), max_lds_align)
+                       : 0;

         static constexpr auto a_block_space_offset = 0;
         static constexpr auto b_block_space_offset = a_block_space_size_aligned;
+        static constexpr auto scale_block_space_offset =
+            b_block_space_offset + b_block_space_size_aligned;

         // LDS allocation for C shuffle in LDS
         static constexpr auto c_shuffle_block_space_size =
@@ -551,7 +604,8 @@ struct GridwiseFpAintBGemm_Wmma
         static constexpr auto lds_size =
             math::max(c_shuffle_block_space_size * sizeof(CShuffleDataType),
                       a_block_space_size_aligned * sizeof(ADataType) +
-                          b_block_space_size_aligned * sizeof(BDataType));
+                          b_block_space_size_aligned * sizeof(BDataType) +
+                          scale_block_space_size_aligned * sizeof(ScaleDataType));
     };

     template <bool HasMainKBlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
@@ -609,7 +663,7 @@ struct GridwiseFpAintBGemm_Wmma
         constexpr auto a_block_desc     = MakeABlockDescriptor();
         constexpr auto b_block_desc     = MakeBBlockDescriptor();
-        constexpr auto scale_block_desc = MakeBBlockDescriptor();
+        constexpr auto scale_block_desc = MakeScaleBlockDescriptor();

         auto a_block_trait = [&](){
             // A matrix blockwise copy
@@ -768,7 +822,7 @@ struct GridwiseFpAintBGemm_Wmma
                 get_thread_local_1d_id() % 16,
                 0));

-            return make_tuple(b_block_buf, b_blockwise_copy, scale_blockwise_copy);
+            return make_tuple(b_block_buf, b_blockwise_copy);
         }
     };
@@ -776,13 +830,14 @@ struct GridwiseFpAintBGemm_Wmma
         if constexpr(BEnableLds)
         {
             constexpr auto K0PerBlock = KPerBlock / K1;
             auto scale_block_buf      = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                 static_cast<ScaleDataType*>(p_shared) + SharedMemTrait::scale_block_space_offset,
                 SharedMemTrait::scale_block_space_size_aligned);

             auto scale_blockwise_copy =
                 ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
-                                                    ck::tensor_operation::element_wise::PassThrough,
+                                                    BElementwiseOperation,
                                                     ck::tensor_operation::element_wise::PassThrough,
                                                     InMemoryDataOperationEnum::Set,
                                                     Sequence<K0PerBlock, NPerBlock, K1>,
@@ -802,10 +857,10 @@ struct GridwiseFpAintBGemm_Wmma
                                                     1,
                                                     1,
                                                     BThreadTransferSrcResetCoordinateAfterRun,
                                                     true,
-                                                    1>(
+                                                    NumGemmKPrefetchStage>(
                     scale_grid_desc,
                     make_multi_index(0, n_block_data_idx_on_grid, 0),
-                    ck::tensor_operation::element_wise::PassThrough{},
+                    b_element_op,
                     scale_block_desc,
                     make_multi_index(0, 0, 0),
                     ck::tensor_operation::element_wise::PassThrough{});
@@ -815,13 +870,12 @@ struct GridwiseFpAintBGemm_Wmma
         else
         {
             // Thread-wise copy
             constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
             constexpr auto K0PerWmma     = WmmaK / 2 / K1Value;
             // KPerBlock/WmmaK -> NRepeat -> NWaves -> WmmaK/K1 -> NPerWmma -> K1
             auto scale_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ScaleDataType>(
                 scale_block_desc.GetElementSpaceSize());

+            // Limitation: NumDim of Src and Dst descriptor should be identical
             auto scale_blockwise_copy =
                 ThreadwiseTensorSliceTransfer_v2<ScaleDataType,
                                                  ScaleDataType,
@@ -894,9 +948,6 @@ struct GridwiseFpAintBGemm_Wmma
         // gridwise GEMM pipeline
         const index_t KBlockMainLoop = __builtin_amdgcn_readfirstlane(K / KPerBlock);
-        /*
-        scale_blockwise_copy
-        */
         GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc,
                                                           a_block_desc,
                                                           a_blockwise_copy,
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp

@@ -12,6 +12,7 @@ enum struct PipelineVersion
 {
     v1,
     v2,
+    dequant_v1,
 };

 template <PipelineVersion PipelineVer,
@@ -36,6 +37,10 @@ constexpr auto GridwiseGemmPipeline_Selector()
     {
         return GridwiseGemmPipeline_v2{};
     }
+    else if constexpr(PipelineVer == PipelineVersion::dequant_v1)
+    {
+        return GridwiseGemmPipeline_v1_dequant<NumPrefetch, AEnableLds, BEnableLds>{};
+    }
     else
     {
         std::cerr << "GridwiseGemmPipeline configuration is not available" << std::endl;
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp

@@ -600,9 +600,9 @@ struct GridwiseGemmPipeline_v1_dequant<1, true, true>
                                        const BBlockTransferStep& b_block_copy_step,

                                        const ScaleGridDesc& scale_grid_desc,
                                        const ScaleBlockDesc& scale_block_desc,
-                                       ScaleBlockTransfer& scale_blockwise_copy,
                                        const ScaleGridBuffer& scale_grid_buf,
                                        ScaleBlockBuffer& scale_block_buf,
+                                       ScaleBlockTransfer& scale_blockwise_copy,

                                        const BlockwiseGemm& blockwise_gemm,
                                        CThreadBuffer& c_thread_buf,
                                        index_t num_loop)
@@ -653,7 +653,7 @@ struct GridwiseGemmPipeline_v1_dequant<1, true, true>
         {
             block_sync_lds();
-            blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
+            blockwise_gemm.Run(a_block_buf, b_block_buf, scale_block_buf, c_thread_buf);
         }
     }
 };
include/ck/utility/data_type.hpp

@@ -1090,6 +1090,16 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_
     return type_convert<bhalf_t>(x_fp32);
 }

+// convert int8 to fp16 via fp32
+template <>
+inline __host__ __device__ constexpr half_t type_convert<half_t, int8_t>(int8_t x)
+{
+    // TODO: replace it with fast_converter
+    float x_fp32 = static_cast<float>(x);
+
+    return type_convert<half_t>(x_fp32);
+}
+
 // Declare a template function for bf16 conversion using RTN
 template <typename Y, typename X>
 __host__ __device__ constexpr Y bf16_convert_rtn(X x);
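The new specialization above gives type_convert a path from int8_t to half_t that widens to fp32 first. A tiny host-only illustration of the intended semantics; plain float stands in for ck::half_t so the snippet compiles without the CK headers:

#include <cassert>
#include <cstdint>

// Mirrors the shape of the new specialization: int8 -> fp32 -> (fp16).
float convert_int8_like_new_specialization(int8_t x)
{
    float x_fp32 = static_cast<float>(x); // widen first, as the added type_convert does
    return x_fp32;                        // the real path then narrows to half_t
}

int main()
{
    assert(convert_int8_like_new_specialization(int8_t{-5}) == -5.0f);
    assert(convert_int8_like_new_specialization(int8_t{127}) == 127.0f); // all int8 values are exact in fp16
    return 0;
}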
library/include/ck/library/reference_tensor_operation/cpu/reference_fpAintB_gemm.hpp
0 → 100644
View file @
0c51a35e
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/utility/host_tensor.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
host
{
template
<
typename
ADataType
,
typename
BDataType
,
typename
ScaleDataType
,
typename
CDataType
,
typename
AccDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
struct
ReferencefpAintBGemm
:
public
device
::
BaseOperator
{
// Argument
struct
Argument
:
public
device
::
BaseArgument
{
Argument
(
const
Tensor
<
ADataType
>&
a_m_k
,
const
Tensor
<
BDataType
>&
b_k_n
,
const
Tensor
<
ScaleDataType
>&
scale_k_n
,
Tensor
<
CDataType
>&
c_m_n
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
)
:
a_m_k_
{
a_m_k
},
b_k_n_
{
b_k_n
},
scale_k_n_
{
scale_k_n
},
c_m_n_
{
c_m_n
},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
c_element_op_
{
c_element_op
}
{
}
const
Tensor
<
ADataType
>&
a_m_k_
;
const
Tensor
<
BDataType
>&
b_k_n_
;
const
Tensor
<
ScaleDataType
>&
scale_k_n_
;
Tensor
<
CDataType
>&
c_m_n_
;
AElementwiseOperation
a_element_op_
;
BElementwiseOperation
b_element_op_
;
CElementwiseOperation
c_element_op_
;
};
    // Invoker
    struct Invoker : public device::BaseInvoker
    {
        using Argument = ReferencefpAintBGemm::Argument;

        float Run(const Argument& arg)
        {
            auto f_mk_kn_mn = [&](auto m, auto n) {
                const int K = arg.a_m_k_.mDesc.GetLengths()[1];

                AccDataType v_acc = 0;

                for(int k = 0; k < K; ++k)
                {
                    ADataType v_a;
                    BDataType v_b;
                    ScaleDataType v_scale;
                    ADataType v_converted_b;

                    // use PassThrough instead of ConvertBF16RTN for reference calculation
                    if constexpr(is_same_v<AElementwiseOperation,
                                           ck::tensor_operation::element_wise::ConvertBF16RTN>)
                    {
                        ck::tensor_operation::element_wise::PassThrough{}(v_a, arg.a_m_k_(m, k));
                    }
                    else
                    {
                        arg.a_element_op_(v_a, arg.a_m_k_(m, k));
                    }

                    // same for B matrix
                    if constexpr(is_same_v<BElementwiseOperation,
                                           ck::tensor_operation::element_wise::ConvertBF16RTN>)
                    {
                        ck::tensor_operation::element_wise::PassThrough{}(v_b, arg.b_k_n_(k, n));
                    }
                    else
                    {
                        arg.b_element_op_(v_b, arg.b_k_n_(k, n));
                    }

                    // same for scale matrix
                    if constexpr(is_same_v<BElementwiseOperation,
                                           ck::tensor_operation::element_wise::ConvertBF16RTN>)
                    {
                        ck::tensor_operation::element_wise::PassThrough{}(v_scale,
                                                                          arg.scale_k_n_(k, n));
                    }
                    else
                    {
                        arg.b_element_op_(v_scale, arg.scale_k_n_(k, n));
                    }

                    // dequantize B: convert the int B element to the A data type and apply its scale
                    v_converted_b = type_convert<ADataType>(v_b) * v_scale;

                    v_acc += ck::type_convert<AccDataType>(v_a) *
                             ck::type_convert<AccDataType>(v_converted_b);
                }

                AccDataType v_c;

                arg.c_element_op_(v_c, v_acc);

                arg.c_m_n_(m, n) = ck::type_convert<CDataType>(v_c);
            };

            make_ParallelTensorFunctor(f_mk_kn_mn,
                                       arg.c_m_n_.mDesc.GetLengths()[0],
                                       arg.c_m_n_.mDesc.GetLengths()[1])(
                std::thread::hardware_concurrency());

            return 0;
        }

        float Run(const device::BaseArgument* p_arg,
                  const StreamConfig& /* stream_config */ = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg));
        }
    };
    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    bool IsSupportedArgument(const device::BaseArgument*) override { return true; }

    static auto MakeArgument(const Tensor<ADataType>& a_m_k,
                             const Tensor<BDataType>& b_k_n,
                             const Tensor<ScaleDataType>& scale_k_n,
                             Tensor<CDataType>& c_m_n,
                             AElementwiseOperation a_element_op,
                             BElementwiseOperation b_element_op,
                             CElementwiseOperation c_element_op)
    {
        return Argument{a_m_k, b_k_n, scale_k_n, c_m_n, a_element_op, b_element_op, c_element_op};
    }

    static auto MakeInvoker() { return Invoker{}; }

    virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "ReferenceGemm" << std::endl;
        // clang-format on

        return str.str();
    }
};

} // namespace host
} // namespace tensor_operation
} // namespace ck
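In words, the reference computes c(m, n) = sum over k of a(m, k) * (convert(b(k, n)) * scale(k, n)) on the host, which is what the device kernels can be verified against. A hypothetical usage sketch follows; the type aliases, the free function, and the assumption that the includes from example/49_fpAintB_gemm/common.hpp (host_tensor.hpp, element_wise_operation.hpp, this header, <cstdint>) are in scope are mine, not code from this commit.

// Hypothetical usage of ReferencefpAintBGemm (illustration only, not part of the commit).
using F16         = ck::half_t;
using I8          = std::int8_t;
using F32         = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

using ReferenceGemmInstance = ck::tensor_operation::host::
    ReferencefpAintBGemm<F16, I8, F16, F16, F32, PassThrough, PassThrough, PassThrough>;

// a_m_k: M x K activations, b_k_n: K x N int8 weights, scale_k_n: per-element scales,
// c_m_n: M x N output, written in place by the host-side loop.
void run_reference_fpAintB_gemm(const Tensor<F16>& a_m_k,
                                const Tensor<I8>& b_k_n,
                                const Tensor<F16>& scale_k_n,
                                Tensor<F16>& c_m_n)
{
    auto ref_op       = ReferenceGemmInstance{};
    auto ref_invoker  = ref_op.MakeInvoker();
    auto ref_argument = ref_op.MakeArgument(
        a_m_k, b_k_n, scale_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{});

    ref_invoker.Run(ref_argument); // loops over (m, n, k) exactly as defined above
}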
library/src/tensor_operation_instance/gpu/gemm/device_gemm_wmma_f16_f16_f16_km_kn_mn_instance.cpp
View file @
0c51a35e
...
@@ -29,9 +29,8 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;

// Compilation parameters for a[k, m] * b[k, n] = c[m, n]
using device_gemm_wmma_f16_f16_f16_km_kn_mn_instances =
    std::tuple<
        // clang-format off
        //######################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumPrefetch| Block| MPer| NPer| KPer| K1| MPer| NPer| M| N| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer|
        //######################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise|Specialization| | Size| Block| Block| Block| | WMMA| WMMA| Repeat| Repeat| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| MRepeat| ClusterLengths| ScalarPerVector|
        //######################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerStore| PerStore| MBlock_MPerBlock| |
...
@@ -62,8 +61,8 @@ using device_gemm_wmma_f16_f16_f16_km_kn_mn_instances =
        // 1 Wave
        DeviceGemmWmma_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 32, 16, 32, 64, 8, 16, 16, 1, 2, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 2>, 8>,
        DeviceGemmWmma_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 32, 16, 16, 64, 8, 16, 16, 1, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 2>, 8>
        // clang-format on
        >;

void add_device_gemm_wmma_f16_f16_f16_km_kn_mn_instances(
    std::vector<std::unique_ptr<
...
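For reference, the layout tags in these instance lists map directly to how the host addresses the matrices: Col for A means the data is indexed as a[k, m], Row for B means b[k, n], and Row for C means c[m, n], exactly as the comment above the tuple states. The sketch below shows that indexing with plain strided arrays; the sizes and values are made up for illustration, and the snippet is not CK code.

// Illustrative only: linear indexing implied by the Col(A) x Row(B) -> Row(C) case,
// i.e. a[k, m] * b[k, n] = c[m, n]. Sizes are arbitrary.
#include <cstddef>
#include <vector>

int main()
{
    const std::size_t M = 2, N = 3, K = 4;

    std::vector<float> a(K * M, 1.0f); // A is M x K, stored column-major as a[k, m]
    std::vector<float> b(K * N, 2.0f); // B is K x N, stored row-major as b[k, n]
    std::vector<float> c(M * N, 0.0f); // C is M x N, stored row-major as c[m, n]

    for(std::size_t m = 0; m < M; ++m)
        for(std::size_t n = 0; n < N; ++n)
        {
            float acc = 0.0f;
            for(std::size_t k = 0; k < K; ++k)
                acc += a[k * M + m] * b[k * N + n]; // Col layout for A, Row layout for B
            c[m * N + n] = acc;                     // every entry becomes K * 1 * 2 = 8
        }
    return 0;
}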
library/src/tensor_operation_instance/gpu/gemm/device_gemm_wmma_f16_f16_f16_km_nk_mn_instance.cpp
View file @
0c51a35e
...
@@ -29,9 +29,8 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;

// Compilation parameters for a[k, m] * b[n, k] = c[m, n]
using device_gemm_wmma_f16_f16_f16_km_nk_mn_instances =
    std::tuple<
        // clang-format off
        //######################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumPrefetch| Block| MPer| NPer| KPer| K1| MPer| NPer| M| N| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer|
        //######################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise|Specialization| | Size| Block| Block| Block| | WMMA| WMMA| Repeat| Repeat| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| MRepeat| ClusterLengths| ScalarPerVector|
        //######################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerStore| PerStore| MBlock_MPerBlock| |
...
@@ -62,8 +61,8 @@ using device_gemm_wmma_f16_f16_f16_km_nk_mn_instances =
        // 1 Wave
        DeviceGemmWmma_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 32, 16, 32, 64, 8, 16, 16, 1, 2, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 2>, 8>,
        DeviceGemmWmma_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 32, 16, 16, 64, 8, 16, 16, 1, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 2>, 8>
        // clang-format on
        >;

void add_device_gemm_wmma_f16_f16_f16_km_nk_mn_instances(
    std::vector<std::unique_ptr<
...
library/src/tensor_operation_instance/gpu/gemm/device_gemm_wmma_f16_f16_f16_mk_kn_mn_instance.cpp
View file @
0c51a35e
...
@@ -29,9 +29,8 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;

// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
using device_gemm_wmma_f16_f16_f16_mk_kn_mn_instances =
    std::tuple<
        // clang-format off
        //######################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumPrefetch| Block| MPer| NPer| KPer| K1| MPer| NPer| M| N| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer|
        //######################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise|Specialization| | Size| Block| Block| Block| | WMMA| WMMA| Repeat| Repeat| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| MRepeat| ClusterLengths| ScalarPerVector|
        //######################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerStore| PerStore| MBlock_MPerBlock| |
...
@@ -143,7 +142,7 @@ using device_gemm_wmma_f16_f16_f16_mk_kn_mn_instances =
        DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 32, 16, 16, 64, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 1, 1, S<1, 16, 1, 2>, 8>
#endif
        // clang-format on
        >;

void add_device_gemm_wmma_f16_f16_f16_mk_kn_mn_instances(
    std::vector<std::unique_ptr<
...