Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
7cd48ef1
Commit
7cd48ef1
authored
Apr 21, 2022
by
Chao Liu
Browse files
refactor
parent
96c73d70
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
27 additions
and
41 deletions
+27
-41
example/14_gemm_xdl_requant_relu_requant/gemm_xdl_requant_relu_requant_int8.cpp
...quant_relu_requant/gemm_xdl_requant_relu_requant_int8.cpp
+22
-8
include/ck/tensor_operation/gpu/element/element_wise_operation.hpp
...k/tensor_operation/gpu/element/element_wise_operation.hpp
+0
-31
library/include/ck/library/utility/check_err.hpp
library/include/ck/library/utility/check_err.hpp
+5
-2
No files found.
example/14_gemm_xdl_requant_relu_requant/gemm_xdl_requant_relu_requant_int8.cpp
View file @
7cd48ef1
...
...
@@ -19,22 +19,36 @@
#include "reference_gemm.hpp"
#include "gemm_specialization.hpp"
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
struct
RequantReluRequant
{
// FIXME: We just need one scale for Relu / Leaky Relu / PRelu
RequantReluRequant
(
float
scaleGemm
,
float
scaleRelu
)
:
scaleGemm_
(
scaleGemm
),
scaleRelu_
(
scaleRelu
)
{
}
using
F32
=
float
;
__host__
__device__
constexpr
void
operator
()(
float
&
y
,
const
float
&
x
)
const
{
float
gemm_requant
=
scaleGemm_
*
x
;
float
relu
=
gemm_requant
>
0
?
gemm_requant
:
0
;
float
relu_requant
=
scaleRelu_
*
relu
;
y
=
relu_requant
>
127
?
127
:
relu_requant
<
-
128
?
-
128
:
relu_requant
;
}
float
scaleGemm_
;
float
scaleRelu_
;
};
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
RequantReluRequant
=
ck
::
tensor_operation
::
element_wise
::
RequantReluRequant
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
ADataType
=
int8_t
;
using
BDataType
=
int8_t
;
using
CDataType
=
int8_t
;
using
AccDataType
=
int32_t
;
using
CShuffleDataType
=
int32_
t
;
using
CShuffleDataType
=
floa
t
;
using
ALayout
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
BLayout
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
...
...
include/ck/tensor_operation/gpu/element/element_wise_operation.hpp
View file @
7cd48ef1
...
...
@@ -143,37 +143,6 @@ struct AddHardswishAdd
}
};
struct
RequantReluRequant
{
// FIXME: We just need one scale for Relu / Leaky Relu / PRelu
RequantReluRequant
(
float
scaleGemm
,
float
scaleRelu
)
:
scaleGemm_
(
scaleGemm
),
scaleRelu_
(
scaleRelu
)
{
}
__host__
__device__
constexpr
void
operator
()(
int8_t
&
y
,
const
int
&
x
)
const
{
float
gemm_requant
=
scaleGemm_
*
static_cast
<
float
>
(
x
);
float
relu
=
gemm_requant
>
0
?
gemm_requant
:
0
;
float
relu_requant
=
scaleRelu_
*
relu
;
y
=
static_cast
<
int8_t
>
(
relu_requant
>
127
?
127
:
relu_requant
<
-
128
?
-
128
:
relu_requant
);
}
// for reference_gemm
__host__
__device__
constexpr
void
operator
()(
float
&
y
,
const
float
&
x
)
const
{
float
gemm_requant
=
scaleGemm_
*
x
;
float
relu
=
gemm_requant
>
0
?
gemm_requant
:
0
;
float
relu_requant
=
scaleRelu_
*
relu
;
y
=
static_cast
<
float
>
(
relu_requant
>
127
?
127
:
relu_requant
<
-
128
?
-
128
:
relu_requant
);
}
float
scaleGemm_
;
float
scaleRelu_
;
};
// Unary operators are usually called element-wisely before/after the reduction is executed on the
// elements. They are needed for easy implementation of reduction types of AVG, NRM1, NRM2
...
...
library/include/ck/library/utility/check_err.hpp
View file @
7cd48ef1
...
...
@@ -171,9 +171,12 @@ check_err(const std::vector<T>& out,
for
(
std
::
size_t
i
=
0
;
i
<
ref
.
size
();
++
i
)
{
if
(
out
[
i
]
!=
ref
[
i
])
const
int64_t
out_v
=
static_cast
<
int64_t
>
(
out
[
i
]);
const
int64_t
ref_v
=
static_cast
<
int64_t
>
(
ref
[
i
]);
if
(
out_v
!=
ref_v
)
{
std
::
cout
<<
"out["
<<
i
<<
"] != ref["
<<
i
<<
"]: "
<<
out
[
i
]
<<
" != "
<<
ref
[
i
]
std
::
cout
<<
"out["
<<
i
<<
"] != ref["
<<
i
<<
"]: "
<<
out
_v
<<
" != "
<<
ref
_v
<<
std
::
endl
<<
msg
<<
std
::
endl
;
return
false
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment