Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
1e447622
Commit
1e447622
authored
Oct 11, 2024
by
letaoqin
Browse files
add bias
parent
951a52b2
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
167 additions
and
20 deletions
+167
-20
example/66_gemm_bias_activation/gemm_bias_add.hpp
example/66_gemm_bias_activation/gemm_bias_add.hpp
+12
-0
example/66_gemm_bias_activation/gemm_bias_add_fp16.cpp
example/66_gemm_bias_activation/gemm_bias_add_fp16.cpp
+43
-0
example/66_gemm_bias_activation/gemm_bias_add_xdl_fp16.cpp
example/66_gemm_bias_activation/gemm_bias_add_xdl_fp16.cpp
+110
-19
include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp
...r_operation/gpu/element/binary_element_wise_operation.hpp
+1
-1
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
...or_operation/gpu/element/unary_element_wise_operation.hpp
+1
-0
No files found.
example/66_gemm_bias_activation/gemm_bias_add.hpp
View file @
1e447622
...
@@ -6,6 +6,18 @@
...
@@ -6,6 +6,18 @@
#include "ck/ck.hpp"
#include "ck/ck.hpp"
#include "ck/stream_config.hpp"
#include "ck/stream_config.hpp"
// Identifiers for the activation variants named in this example.
// Values are consecutive starting at 0; InvalidType is the last enumerator.
enum class ActivationType
{
    Gelu                 = 0,
    Relu                 = 1,
    Silu                 = 2,
    Swiglu               = 3,
    Geglu                = 4,
    Identity             = 5,
    GeluNoneApproximate  = 6,
    GeGluNoneApproximate = 7,
    InvalidType          = 8
};
struct
GemmBiasAddArgs
struct
GemmBiasAddArgs
{
{
const
void
*
mat_a
;
const
void
*
mat_a
;
...
...
example/66_gemm_bias_activation/gemm_bias_add_fp16.cpp
View file @
1e447622
...
@@ -38,6 +38,49 @@ using S = ck::Sequence<Is...>;
...
@@ -38,6 +38,49 @@ using S = ck::Sequence<Is...>;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
;
namespace ck {
namespace impl {

// Binary element-wise functor: sums its two inputs and applies `Activation`
// to the sum.
//
// `Activation` is default-constructed on every call and invoked as
// `Activation{}.template operator()<float>(out, sum)`, i.e. the activation is
// always evaluated in float. Overloads whose output is half_t compute into a
// float temporary and convert the result with type_convert<half_t>.
template <typename Activation>
struct AddActivation
{
    // Primary template is declared only; the supported type combinations are
    // the explicit specializations defined below.
    template <typename Y, typename X0, typename X1>
    __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const;

    // float output, float + float inputs.
    template <>
    __host__ __device__ constexpr void
    operator()<float>(float& y, const float& x0, const float& x1) const
    {
        Activation{}.template operator()<float>(y, x0 + x1);
    }

    // float output, float + half_t inputs: x1 is widened to float before the add.
    template <>
    __host__ __device__ constexpr void
    operator()<float>(float& y, const float& x0, const half_t& x1) const
    {
        float x = x0 + type_convert<float>(x1);
        Activation{}.template operator()<float>(y, x);
    }

    // half_t output, float + float inputs: activate in float, then narrow.
    template <>
    __host__ __device__ constexpr void
    operator()<half_t>(half_t& y, const float& x0, const float& x1) const
    {
        float result = 0;
        Activation{}.template operator()<float>(result, x0 + x1);
        y = type_convert<half_t>(result);
    }

    // half_t output, float + half_t inputs. x1 is converted explicitly with
    // type_convert<float>, matching the float-output/half_t-input overload
    // above instead of relying on implicit promotion in `x0 + x1`.
    template <>
    __host__ __device__ constexpr void
    operator()<half_t>(half_t& y, const float& x0, const half_t& x1) const
    {
        float result = 0;
        Activation{}.template operator()<float>(result, x0 + type_convert<float>(x1));
        y = type_convert<half_t>(result);
    }
};

} // namespace impl
} // namespace ck
// clang-format off
// clang-format off
template
<
typename
ADataType
,
typename
BDataType
,
typename
DsDataType
,
typename
CDataType
>
template
<
typename
ADataType
,
typename
BDataType
,
typename
DsDataType
,
typename
CDataType
>
using
DeviceOpInstance_64_16_16_64
=
ck
::
tensor_operation
::
device
::
DeviceGemmMultiD_Xdl_CShuffle_V3
<
using
DeviceOpInstance_64_16_16_64
=
ck
::
tensor_operation
::
device
::
DeviceGemmMultiD_Xdl_CShuffle_V3
<
...
...
example/66_gemm_bias_activation/gemm_bias_add_xdl_fp16.cpp
View file @
1e447622
...
@@ -13,6 +13,7 @@
...
@@ -13,6 +13,7 @@
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
...
@@ -38,11 +39,11 @@ using DsLayout = ck::Tuple<D0Layout>;
...
@@ -38,11 +39,11 @@ using DsLayout = ck::Tuple<D0Layout>;
using
ELayout
=
Row
;
using
ELayout
=
Row
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
//
using Add = ck::tensor_operation::element_wise::Add;
using
Add
=
ck
::
tensor_operation
::
element_wise
::
Add
;
using
AElementOp
=
PassThrough
;
using
AElementOp
=
PassThrough
;
using
BElementOp
=
PassThrough
;
using
BElementOp
=
PassThrough
;
using
CElementOp
=
PassThrough
;
using
CElementOp
=
Add
;
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
A0DataType
,
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
A0DataType
,
B0DataType
,
B0DataType
,
...
@@ -50,8 +51,88 @@ using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<A0DataTy
...
@@ -50,8 +51,88 @@ using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<A0DataTy
AccDataType
,
AccDataType
,
AElementOp
,
AElementOp
,
BElementOp
,
BElementOp
,
CElementOp
>
;
PassThrough
>
;
// Returns the relative tolerance associated with element type `DataType`.
// The first matching type in the chain below wins; unlisted types get 1e-3.
template <typename DataType>
inline __host__ __device__ constexpr double get_rtol()
{
    constexpr bool is_fp32 = std::is_same_v<DataType, float>;
    constexpr bool is_fp64 = std::is_same_v<DataType, double>;
    constexpr bool is_fp16 = std::is_same_v<DataType, ck::half_t>;
    constexpr bool is_bf16 = std::is_same_v<DataType, ck::bhalf_t>;
    constexpr bool is_i32  = std::is_same_v<DataType, int32_t>;
    constexpr bool is_i8   = std::is_same_v<DataType, int8_t>;
    constexpr bool is_f8   = std::is_same_v<DataType, ck::f8_t>;
    constexpr bool is_bf8  = std::is_same_v<DataType, ck::bf8_t>;

    if constexpr(is_fp32)
        return 1e-3;
    else if constexpr(is_fp64)
        return 1e-6;
    else if constexpr(is_fp16)
        return 1e-3;
    else if constexpr(is_bf16)
        return 5e-2;
    else if constexpr(is_i32 || is_i8 || is_f8)
        return 1e-1; // for f8_t: 240 and 224 are acceptable
    else if constexpr(is_bf8)
        return 1.5e-1; // 57344 and 49152 are acceptable
    else
        return 1e-3;
}
// Returns the absolute tolerance associated with element type `DataType`.
// The first matching type in the chain below wins; unlisted types get 1e-3.
template <typename DataType>
inline __host__ __device__ constexpr double get_atol()
{
    constexpr bool is_fp32 = std::is_same_v<DataType, float>;
    constexpr bool is_fp64 = std::is_same_v<DataType, double>;
    constexpr bool is_fp16 = std::is_same_v<DataType, ck::half_t>;
    constexpr bool is_bf16 = std::is_same_v<DataType, ck::bhalf_t>;
    constexpr bool is_i32  = std::is_same_v<DataType, int32_t>;
    constexpr bool is_i8   = std::is_same_v<DataType, int8_t>;
    constexpr bool is_f8   = std::is_same_v<DataType, ck::f8_t>;
    constexpr bool is_bf8  = std::is_same_v<DataType, ck::bf8_t>;

    if constexpr(is_fp32)
        return 1e-3;
    else if constexpr(is_fp64)
        return 1e-6;
    else if constexpr(is_fp16)
        return 1e-3;
    else if constexpr(is_bf16)
        return 5e-2;
    else if constexpr(is_i32 || is_i8)
        return 1e-1;
    else if constexpr(is_f8)
        return 16.1; // 240 and 224 are acceptable
    else if constexpr(is_bf8)
        return 8192.1; // 57344 and 49152 are acceptable
    else
        return 1e-3;
}
int
main
(
int
argc
,
char
*
argv
[])
int
main
(
int
argc
,
char
*
argv
[])
{
{
bool
do_verification
=
true
;
bool
do_verification
=
true
;
...
@@ -63,11 +144,6 @@ int main(int argc, char* argv[])
...
@@ -63,11 +144,6 @@ int main(int argc, char* argv[])
ck
::
index_t
N
=
16
;
ck
::
index_t
N
=
16
;
ck
::
index_t
K
=
64
;
ck
::
index_t
K
=
64
;
ck
::
index_t
StrideA
=
K
;
ck
::
index_t
StrideB
=
N
;
ck
::
index_t
StrideD
=
0
;
ck
::
index_t
StrideE
=
N
;
if
(
argc
==
1
)
if
(
argc
==
1
)
{
{
// use default case
// use default case
...
@@ -78,7 +154,7 @@ int main(int argc, char* argv[])
...
@@ -78,7 +154,7 @@ int main(int argc, char* argv[])
init_method
=
std
::
stoi
(
argv
[
2
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
}
}
else
if
(
argc
==
11
)
else
if
(
argc
==
7
)
{
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
...
@@ -87,21 +163,21 @@ int main(int argc, char* argv[])
...
@@ -87,21 +163,21 @@ int main(int argc, char* argv[])
M
=
std
::
stoi
(
argv
[
4
]);
M
=
std
::
stoi
(
argv
[
4
]);
N
=
std
::
stoi
(
argv
[
5
]);
N
=
std
::
stoi
(
argv
[
5
]);
K
=
std
::
stoi
(
argv
[
6
]);
K
=
std
::
stoi
(
argv
[
6
]);
StrideA
=
std
::
stoi
(
argv
[
7
]);
StrideB
=
std
::
stoi
(
argv
[
8
]);
StrideD
=
std
::
stoi
(
argv
[
9
]);
StrideE
=
std
::
stoi
(
argv
[
10
]);
}
}
else
else
{
{
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg3: time kernel (0=no, 1=yes)
\n
"
);
printf
(
"arg3: time kernel (0=no, 1=yes)
\n
"
);
printf
(
"arg4 to 9: M (256x), N(128x), K(32x)
, StrideA, StrideB, StrideD, StrideE
\n
"
);
printf
(
"arg4 to 9: M (256x), N(128x), K(32x)
m
\n
"
);
exit
(
0
);
exit
(
0
);
}
}
ck
::
index_t
StrideA
=
K
;
ck
::
index_t
StrideB
=
N
;
ck
::
index_t
StrideD
=
0
;
ck
::
index_t
StrideE
=
N
;
auto
f_host_tensor_descriptor
=
auto
f_host_tensor_descriptor
=
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
auto
layout
)
{
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
auto
layout
)
{
using
namespace
ck
::
literals
;
using
namespace
ck
::
literals
;
...
@@ -132,12 +208,12 @@ int main(int argc, char* argv[])
...
@@ -132,12 +208,12 @@ int main(int argc, char* argv[])
case
1
:
case
1
:
a0_m_k
.
GenerateTensorValue
(
GeneratorTensor_3
<
A0DataType
>
{
-
0.5
,
0.5
});
a0_m_k
.
GenerateTensorValue
(
GeneratorTensor_3
<
A0DataType
>
{
-
0.5
,
0.5
});
b0_k_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
B0DataType
>
{
-
0.5
,
0.5
});
b0_k_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
B0DataType
>
{
-
0.5
,
0.5
});
d0_m_n
.
GenerateTensorValue
(
GeneratorTensor_
1
<
D0DataType
>
{
0
});
d0_m_n
.
GenerateTensorValue
(
GeneratorTensor_
3
<
D0DataType
>
{
-
0.5
,
0.5
});
break
;
break
;
default:
default:
a0_m_k
.
GenerateTensorValue
(
GeneratorTensor_3
<
A0DataType
>
{
0.0
,
1.0
});
a0_m_k
.
GenerateTensorValue
(
GeneratorTensor_3
<
A0DataType
>
{
0.0
,
1.0
});
b0_k_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
B0DataType
>
{
-
0.5
,
0.5
});
b0_k_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
B0DataType
>
{
-
0.5
,
0.5
});
d0_m_n
.
GenerateTensorValue
(
GeneratorTensor_
1
<
D0DataType
>
{
0
});
d0_m_n
.
GenerateTensorValue
(
GeneratorTensor_
3
<
D0DataType
>
{
-
0.5
,
0.5
});
}
}
DeviceMem
a0_device_buf
(
sizeof
(
A0DataType
)
*
a0_m_k
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
a0_device_buf
(
sizeof
(
A0DataType
)
*
a0_m_k
.
mDesc
.
GetElementSpaceSize
());
...
@@ -183,13 +259,28 @@ int main(int argc, char* argv[])
...
@@ -183,13 +259,28 @@ int main(int argc, char* argv[])
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
auto
ref_argument
=
ref_gemm
.
MakeArgument
(
auto
ref_argument
=
ref_gemm
.
MakeArgument
(
a0_m_k
,
b0_k_n
,
e_m_n_host_result
,
AElementOp
{},
BElementOp
{},
CElementOp
{});
a0_m_k
,
b0_k_n
,
e_m_n_host_result
,
AElementOp
{},
BElementOp
{},
PassThrough
{});
ref_invoker
.
Run
(
ref_argument
);
ref_invoker
.
Run
(
ref_argument
);
CElementOp
cde_element_op
;
for
(
int
m
=
0
;
m
<
M
;
++
m
)
{
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
cde_element_op
(
e_m_n_host_result
(
m
,
n
),
e_m_n_host_result
(
m
,
n
),
d0_m_n
(
m
,
n
));
}
}
e_device_buf
.
FromDevice
(
e_m_n_device_result
.
mData
.
data
());
e_device_buf
.
FromDevice
(
e_m_n_device_result
.
mData
.
data
());
return
ck
::
utils
::
check_err
(
e_m_n_device_result
,
e_m_n_host_result
)
?
0
:
1
;
return
ck
::
utils
::
check_err
(
e_m_n_device_result
,
e_m_n_host_result
,
"Error: Incorrect results!"
,
get_rtol
<
EDataType
>
(),
get_atol
<
EDataType
>
())
?
0
:
1
;
}
}
return
0
;
return
0
;
...
...
include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp
View file @
1e447622
...
@@ -33,7 +33,7 @@ struct Add
...
@@ -33,7 +33,7 @@ struct Add
__host__
__device__
constexpr
void
__host__
__device__
constexpr
void
operator
()
<
float
>
(
float
&
y
,
const
float
&
x0
,
const
half_t
&
x1
)
const
operator
()
<
float
>
(
float
&
y
,
const
float
&
x0
,
const
half_t
&
x1
)
const
{
{
y
=
x0
+
type_convert
<
half_
t
>
(
x1
);
y
=
x0
+
type_convert
<
floa
t
>
(
x1
);
};
};
template
<
>
template
<
>
...
...
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
View file @
1e447622
...
@@ -1077,6 +1077,7 @@ struct ConvScaleRelu
...
@@ -1077,6 +1077,7 @@ struct ConvScaleRelu
float
scale_out_
;
float
scale_out_
;
};
};
// support fastconvert of int8 to fp16
// support fastconvert of int8 to fp16
template
<
typename
InputDataType
,
typename
OutputDataType
,
index_t
RegPackNumber
>
template
<
typename
InputDataType
,
typename
OutputDataType
,
index_t
RegPackNumber
>
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment