Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
d8552699
Commit
d8552699
authored
Feb 22, 2023
by
Chao Liu
Browse files
fix compilation
parent
40fabdcd
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
55 additions
and
23 deletions
+55
-23
example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_bf16.cpp
..._gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_bf16.cpp
+6
-5
example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_fp16.cpp
..._gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_fp16.cpp
+2
-2
example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_fp32.cpp
..._gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_fp32.cpp
+6
-6
example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_int4.cpp
..._gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_int4.cpp
+6
-5
example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_int8.cpp
..._gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_int8.cpp
+6
-5
include/ck/tensor_operation/gpu/element/element_wise_operation.hpp
...k/tensor_operation/gpu/element/element_wise_operation.hpp
+29
-0
No files found.
example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_bf16.cpp
View file @
d8552699
...
@@ -7,6 +7,7 @@ using ADataType = BF16;
...
@@ -7,6 +7,7 @@ using ADataType = BF16;
using
BDataType
=
BF16
;
using
BDataType
=
BF16
;
using
AccDataType
=
F32
;
using
AccDataType
=
F32
;
using
CShuffleDataType
=
F32
;
using
CShuffleDataType
=
F32
;
using
CDataType
=
F32
;
// C matrix doesn't exsit in GPU memory, this is used for host verification
using
D0DataType
=
BF16
;
using
D0DataType
=
BF16
;
using
D1DataType
=
BF16
;
using
D1DataType
=
BF16
;
using
DsDataType
=
ck
::
Tuple
<
D0DataType
,
D1DataType
>
;
using
DsDataType
=
ck
::
Tuple
<
D0DataType
,
D1DataType
>
;
...
@@ -36,7 +37,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_C
...
@@ -36,7 +37,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_C
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
BDataType
,
Acc
DataType
,
C
DataType
,
AccDataType
,
AccDataType
,
AElementOp
,
AElementOp
,
BElementOp
,
BElementOp
,
...
...
example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_fp16.cpp
View file @
d8552699
...
@@ -6,8 +6,8 @@
...
@@ -6,8 +6,8 @@
using
ADataType
=
F16
;
using
ADataType
=
F16
;
using
BDataType
=
F16
;
using
BDataType
=
F16
;
using
AccDataType
=
F32
;
using
AccDataType
=
F32
;
using
CShuffleDataType
=
F
16
;
using
CShuffleDataType
=
F
32
;
using
CDataType
=
F
16
;
// C matrix doesn't exsit in GPU memory, this is used for host verification
using
CDataType
=
F
32
;
// C matrix doesn't exsit in GPU memory, this is used for host verification
using
D0DataType
=
F16
;
using
D0DataType
=
F16
;
using
D1DataType
=
F16
;
using
D1DataType
=
F16
;
using
DsDataType
=
ck
::
Tuple
<
D0DataType
,
D1DataType
>
;
using
DsDataType
=
ck
::
Tuple
<
D0DataType
,
D1DataType
>
;
...
...
example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_fp32.cpp
View file @
d8552699
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "common.hpp"
...
@@ -7,6 +6,7 @@ using ADataType = F32;
...
@@ -7,6 +6,7 @@ using ADataType = F32;
using
BDataType
=
F32
;
using
BDataType
=
F32
;
using
AccDataType
=
F32
;
using
AccDataType
=
F32
;
using
CShuffleDataType
=
F32
;
using
CShuffleDataType
=
F32
;
using
CDataType
=
F32
;
// C matrix doesn't exsit in GPU memory, this is used for host verification
using
D0DataType
=
F32
;
using
D0DataType
=
F32
;
using
D1DataType
=
F32
;
using
D1DataType
=
F32
;
using
DsDataType
=
ck
::
Tuple
<
D0DataType
,
D1DataType
>
;
using
DsDataType
=
ck
::
Tuple
<
D0DataType
,
D1DataType
>
;
...
@@ -36,7 +36,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_C
...
@@ -36,7 +36,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_C
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
BDataType
,
Acc
DataType
,
C
DataType
,
AccDataType
,
AccDataType
,
AElementOp
,
AElementOp
,
BElementOp
,
BElementOp
,
...
...
example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_int4.cpp
View file @
d8552699
...
@@ -11,6 +11,7 @@ using ADataType = I4;
...
@@ -11,6 +11,7 @@ using ADataType = I4;
using
BDataType
=
I4
;
using
BDataType
=
I4
;
using
AccDataType
=
I32
;
using
AccDataType
=
I32
;
using
CShuffleDataType
=
I32
;
using
CShuffleDataType
=
I32
;
using
CDataType
=
I32
;
// C matrix doesn't exsit in GPU memory, this is used for host verification
using
D0DataType
=
I4
;
using
D0DataType
=
I4
;
using
D1DataType
=
I4
;
using
D1DataType
=
I4
;
using
DsDataType
=
ck
::
Tuple
<
D0DataType
,
D1DataType
>
;
using
DsDataType
=
ck
::
Tuple
<
D0DataType
,
D1DataType
>
;
...
@@ -47,7 +48,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_C
...
@@ -47,7 +48,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_C
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
BDataType
,
Acc
DataType
,
C
DataType
,
AccDataType
,
AccDataType
,
AElementOp
,
AElementOp
,
BElementOp
,
BElementOp
,
...
...
example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_int8.cpp
View file @
d8552699
...
@@ -7,6 +7,7 @@ using ADataType = I8;
...
@@ -7,6 +7,7 @@ using ADataType = I8;
using
BDataType
=
I8
;
using
BDataType
=
I8
;
using
AccDataType
=
I32
;
using
AccDataType
=
I32
;
using
CShuffleDataType
=
I32
;
using
CShuffleDataType
=
I32
;
using
CDataType
=
I32
;
// C matrix doesn't exsit in GPU memory, this is used for host verification
using
D0DataType
=
I8
;
using
D0DataType
=
I8
;
using
D1DataType
=
I8
;
using
D1DataType
=
I8
;
using
DsDataType
=
ck
::
Tuple
<
D0DataType
,
D1DataType
>
;
using
DsDataType
=
ck
::
Tuple
<
D0DataType
,
D1DataType
>
;
...
@@ -36,7 +37,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_C
...
@@ -36,7 +37,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_C
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
BDataType
,
Acc
DataType
,
C
DataType
,
AccDataType
,
AccDataType
,
AElementOp
,
AElementOp
,
BElementOp
,
BElementOp
,
...
...
include/ck/tensor_operation/gpu/element/element_wise_operation.hpp
View file @
d8552699
...
@@ -235,6 +235,35 @@ struct AddAddFastGelu
...
@@ -235,6 +235,35 @@ struct AddAddFastGelu
e
=
type_convert
<
half_t
>
(
x1_f
);
e
=
type_convert
<
half_t
>
(
x1_f
);
}
}
template
<
>
__host__
__device__
constexpr
void
operator
()
<
bhalf_t
,
float
,
bhalf_t
,
bhalf_t
>
(
bhalf_t
&
e
,
const
float
&
c
,
const
bhalf_t
&
d0
,
const
bhalf_t
&
d1
)
const
{
const
float
x0_f
=
c
+
type_convert
<
float
>
(
d0
)
+
type_convert
<
float
>
(
d1
);
float
x1_f
=
0
;
ck
::
tensor_operation
::
element_wise
::
FastGelu
{}.
template
operator
()
<
float
,
float
>(
x1_f
,
x0_f
);
e
=
type_convert
<
bhalf_t
>
(
x1_f
);
}
template
<
>
__host__
__device__
constexpr
void
operator
()
<
int8_t
,
int32_t
,
int8_t
,
int8_t
>
(
int8_t
&
e
,
const
int32_t
&
c
,
const
int8_t
&
d0
,
const
int8_t
&
d1
)
const
{
const
float
x0_f
=
type_convert
<
float
>
(
c
)
+
type_convert
<
float
>
(
d0
)
+
type_convert
<
float
>
(
d1
);
float
x1_f
=
0
;
ck
::
tensor_operation
::
element_wise
::
FastGelu
{}.
template
operator
()
<
float
,
float
>(
x1_f
,
x0_f
);
e
=
type_convert
<
int8_t
>
(
x1_f
);
}
};
};
struct
Normalize
struct
Normalize
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment