Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
8a60a329
Commit
8a60a329
authored
Jun 11, 2022
by
Chao Liu
Browse files
add gemm bias add fastgelu
parent
c7d59414
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
41 additions
and
75 deletions
+41
-75
example/03_gemm_bias_fastgelu/gemm_bias_fastgelu_xdl_fp16.cpp
...ple/03_gemm_bias_fastgelu/gemm_bias_fastgelu_xdl_fp16.cpp
+31
-73
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp
...ration/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp
+9
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
...ration/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
+1
-2
No files found.
example/03_gemm_bias_fastgelu/gemm_bias_fastgelu_xdl_fp16.cpp
View file @
8a60a329
...
...
@@ -25,7 +25,6 @@ using Row = ck::tensor_layout::gemm::RowMajor;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
FastGelu
=
ck
::
tensor_operation
::
element_wise
::
FastGelu
;
struct
AddAddFastGelu
{
...
...
@@ -34,51 +33,41 @@ struct AddAddFastGelu
__host__
__device__
void
operator
()(
ck
::
half_t
&
y
,
const
float
&
x0
,
const
ck
::
half_t
&
x1
,
const
ck
::
half_t
&
x2
)
const
{
#if 0
const
float
x
=
x0
+
x1
+
x2
;
const
float
u
=
float
(
2
)
*
x
*
(
float
(
0.035677
)
*
x
*
x
+
float
(
0.797885
));
const
float
emu
=
exp
(
-
u
);
const
float
cdf
=
float
(
0.5
)
+
float
(
0.5
)
*
(
float
(
2
)
/
(
float
(
1
)
+
emu
)
-
float
(
1
));
y = x * cdf;
#else
const
float
x
=
x0
+
x2
;
y
=
x
;
#endif
y
=
ck
::
type_convert
<
ck
::
half_t
>
(
x
*
cdf
);
}
};
using
ADataType
=
F16
;
using
BDataType
=
F16
;
using
AccDataType
=
F32
;
using
D0DataType
=
F16
;
using
D1DataType
=
F16
;
using
DsDataType
=
ck
::
Tuple
<
D0DataType
,
D1DataType
>
;
using
EDataType
=
F16
;
using
ALayout
=
Row
;
using
BLayout
=
Col
;
using
ELayout
=
Row
;
using
AccDataType
=
F32
;
using
AElementOp
=
PassThrough
;
using
BElementOp
=
PassThrough
;
#if 0
using CDEElementOp = FastGelu;
#else
using
ADataType
=
F16
;
using
BDataType
=
F16
;
using
AccDataType
=
F32
;
using
CShuffleDataType
=
F32
;
using
D0DataType
=
F16
;
using
D1DataType
=
F16
;
using
DsDataType
=
ck
::
Tuple
<
D0DataType
,
D1DataType
>
;
using
EDataType
=
F16
;
using
ALayout
=
Row
;
using
BLayout
=
Col
;
using
ELayout
=
Row
;
using
AElementOp
=
PassThrough
;
using
BElementOp
=
PassThrough
;
using
CDEElementOp
=
AddAddFastGelu
;
#endif
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
// clang-format off
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemmMultipleD_Xdl_CShuffle
//######| ALayout| BLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B|
C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | |
|
|
|
| | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | |
|
|
|
| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F32
,
DsDataType
,
F16
,
AElementOp
,
BElementOp
,
CDEElementOp
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
;
//######| ALayout| BLayout| ELayout|
AData|
BData|
AccData|
CShuffle| DsData|
EData| A| B| C
DE
| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | |
Type|
Type|
Type|
DataType| Type|
Type| Elementwise| Elementwise|
Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | |
|
|
|
| |
| Operation| Operation|
Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | |
|
|
|
| |
| | |
| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
<
ALayout
,
BLayout
,
ELayout
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
;
// clang-format on
int
main
(
int
argc
,
char
*
argv
[])
...
...
@@ -160,28 +149,16 @@ int main(int argc, char* argv[])
{
case
0
:
break
;
case
1
:
#if 0
a_m_k
.
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
5
,
5
});
b_k_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
BDataType
>
{
-
5
,
5
});
d0_m_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
EDataType
>
{
-
5
,
5
});
d1_m_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
EDataType
>
{
-
5
,
5
});
#else
a_m_k
.
GenerateTensorValue
(
GeneratorTensor_1
<
ADataType
>
{
1
});
b_k_n
.
GenerateTensorValue
(
GeneratorTensor_1
<
BDataType
>
{
1
});
d0_m_n
.
GenerateTensorValue
(
GeneratorTensor_1
<
EDataType
>
{
1
});
d1_m_n
.
GenerateTensorValue
(
GeneratorTensor_1
<
EDataType
>
{
1
});
#endif
break
;
default:
a_m_k
.
GenerateTensorValue
(
GeneratorTensor_3
<
ADataType
>
{
0.0
,
1.0
});
b_k_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
BDataType
>
{
-
0.5
,
0.5
});
#if 0
d0_m_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
EDataType
>
{
0.0
,
1.0
});
d1_m_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
EDataType
>
{
0.0
,
1.0
});
#else
d0_m_n
.
GenerateTensorValue
(
GeneratorTensor_1
<
EDataType
>
{
1
});
d1_m_n
.
GenerateTensorValue
(
GeneratorTensor_1
<
EDataType
>
{
1
});
#endif
}
std
::
cout
<<
"a: "
<<
a_m_k
.
mDesc
.
GetElementSpace
()
<<
std
::
endl
;
...
...
@@ -192,16 +169,14 @@ int main(int argc, char* argv[])
DeviceMem
a_m_k_device_buf
(
sizeof
(
ADataType
)
*
a_m_k
.
mDesc
.
GetElementSpace
());
DeviceMem
b_k_n_device_buf
(
sizeof
(
BDataType
)
*
b_k_n
.
mDesc
.
GetElementSpace
());
#if 1
DeviceMem
d0_m_n_device_buf
(
sizeof
(
EDataType
)
*
d0_m_n
.
mDesc
.
GetElementSpace
());
#else
DeviceMem
d0_m_n_device_buf
(
sizeof
(
EDataType
)
*
d1_m_n
.
mDesc
.
GetElementSpace
());
#endif
DeviceMem
d1_m_n_device_buf
(
sizeof
(
EDataType
)
*
d1_m_n
.
mDesc
.
GetElementSpace
());
DeviceMem
d0_m_n_device_buf
(
sizeof
(
D0DataType
)
*
d0_m_n
.
mDesc
.
GetElementSpace
());
DeviceMem
d1_m_n_device_buf
(
sizeof
(
D1DataType
)
*
d1_m_n
.
mDesc
.
GetElementSpace
());
DeviceMem
e_m_n_device_buf
(
sizeof
(
EDataType
)
*
e_m_n_device_result
.
mDesc
.
GetElementSpace
());
a_m_k_device_buf
.
ToDevice
(
a_m_k
.
mData
.
data
());
b_k_n_device_buf
.
ToDevice
(
b_k_n
.
mData
.
data
());
d0_m_n_device_buf
.
ToDevice
(
d0_m_n
.
mData
.
data
());
d1_m_n_device_buf
.
ToDevice
(
d1_m_n
.
mData
.
data
());
auto
a_element_op
=
AElementOp
{};
auto
b_element_op
=
BElementOp
{};
...
...
@@ -236,9 +211,10 @@ int main(int argc, char* argv[])
float
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
time_kernel
});
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
M
*
N
*
K
;
std
::
size_t
num_btype
=
sizeof
(
ADataType
)
*
M
*
K
+
sizeof
(
BDataType
)
*
K
*
N
+
sizeof
(
EDataType
)
*
M
*
N
;
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
M
*
N
*
K
;
std
::
size_t
num_btype
=
sizeof
(
ADataType
)
*
M
*
K
+
sizeof
(
BDataType
)
*
K
*
N
+
sizeof
(
D0DataType
)
*
N
+
sizeof
(
D1DataType
)
*
M
*
N
+
sizeof
(
EDataType
)
*
M
*
N
;
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
...
...
@@ -247,11 +223,10 @@ int main(int argc, char* argv[])
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
gemm
.
GetTypeString
()
<<
std
::
endl
;
e_m_n_device_buf
.
FromDevice
(
e_m_n_device_result
.
mData
.
data
());
if
(
do_verification
)
{
#if 1
e_m_n_device_buf
.
FromDevice
(
e_m_n_device_result
.
mData
.
data
());
Tensor
<
AccDataType
>
c_m_n
(
f_host_tensor_descriptor
(
M
,
N
,
StrideE
,
ELayout
{}));
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
...
...
@@ -276,23 +251,6 @@ int main(int argc, char* argv[])
cde_element_op
(
e_m_n_host_result
(
m
,
n
),
c_m_n
(
m
,
n
),
d0_m_n
(
m
,
n
),
d1_m_n
(
m
,
n
));
}
}
#else
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
EDataType
,
AccDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
>
;
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
auto
ref_argument
=
ref_gemm
.
MakeArgument
(
a_m_k
,
b_k_n
,
e_m_n_host_result
,
a_element_op
,
b_element_op
,
CDEElementOp
{});
ref_invoker
.
Run
(
ref_argument
);
#endif
return
ck
::
utils
::
check_err
(
e_m_n_device_result
.
mData
,
e_m_n_host_result
.
mData
)
?
0
:
1
;
}
...
...
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp
View file @
8a60a329
...
...
@@ -146,6 +146,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<DsDataType:
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
index_t
MRaw
,
index_t
KRaw
,
index_t
StrideA
)
{
...
...
@@ -575,12 +576,20 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<DsDataType:
<<
arg
.
ds_grid_desc_mblock_mperblock_nblock_nperblock_
[
I0
].
GetLength
(
I0
)
<<
", "
<<
arg
.
ds_grid_desc_mblock_mperblock_nblock_nperblock_
[
I0
].
GetLength
(
I1
)
<<
", "
<<
arg
.
ds_grid_desc_mblock_mperblock_nblock_nperblock_
[
I0
].
GetLength
(
I2
)
<<
", "
<<
arg
.
ds_grid_desc_mblock_mperblock_nblock_nperblock_
[
I0
].
GetLength
(
I3
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_{ "
<<
arg
.
ds_grid_desc_mblock_mperblock_nblock_nperblock_
[
I1
].
GetLength
(
I0
)
<<
", "
<<
arg
.
ds_grid_desc_mblock_mperblock_nblock_nperblock_
[
I1
].
GetLength
(
I1
)
<<
", "
<<
arg
.
ds_grid_desc_mblock_mperblock_nblock_nperblock_
[
I1
].
GetLength
(
I2
)
<<
", "
<<
arg
.
ds_grid_desc_mblock_mperblock_nblock_nperblock_
[
I1
].
GetLength
(
I3
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"p_ds_grid{ "
<<
arg
.
p_ds_grid_
[
I0
]
<<
", "
<<
arg
.
p_ds_grid_
[
I1
]
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
View file @
8a60a329
...
...
@@ -660,14 +660,13 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
if
constexpr
(
access_id
<
num_access
-
1
)
{
constexpr
auto
c_global_step
=
sfc_c_global
.
GetForwardStep
(
access_id
);
#if 1
// move on Ds
c_shuffle_block_copy_lds_to_global
.
MoveSrc1SliceWindow
(
ds_grid_desc_mblock_mperblock_nblock_nperblock
[
I0
],
c_global_step
);
c_shuffle_block_copy_lds_to_global
.
MoveSrc2SliceWindow
(
ds_grid_desc_mblock_mperblock_nblock_nperblock
[
I1
],
c_global_step
);
#endif
// move on C
c_shuffle_block_copy_lds_to_global
.
MoveDstSliceWindow
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment