Commit 1e447622 authored by letaoqin's avatar letaoqin
Browse files

add bias

parent 951a52b2
...@@ -6,6 +6,18 @@ ...@@ -6,6 +6,18 @@
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/stream_config.hpp" #include "ck/stream_config.hpp"
// Runtime selector for the fused GEMM epilogue activation.
// Names suggest: "NoneApproximate" variants use the exact (erf-based) GELU
// rather than the tanh approximation, and the *glu entries are gated-linear-unit
// style activations -- NOTE(review): semantics are assumed from the names;
// confirm against the dispatch code that consumes this enum.
enum class ActivationType
{
Gelu = 0,
Relu,
Silu,
Swiglu,
Geglu,
Identity,
GeluNoneApproximate,
GeGluNoneApproximate,
InvalidType // sentinel marking an unrecognized/unsupported activation
};
struct GemmBiasAddArgs struct GemmBiasAddArgs
{ {
const void* mat_a; const void* mat_a;
......
...@@ -38,6 +38,49 @@ using S = ck::Sequence<Is...>; ...@@ -38,6 +38,49 @@ using S = ck::Sequence<Is...>;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
namespace ck {
namespace impl {

// Fused "bias-add then activation" epilogue functor: y = Activation(x0 + x1).
// x0 is the GEMM accumulator value (float), x1 the bias element; the add and
// the activation are always evaluated in float, and the result is narrowed
// to the destination type Y when Y is half_t. Uses the same in-class
// explicit-specialization pattern as the CK element_wise operators.
template <typename Activation>
struct AddActivation
{
    // Generic declaration; only the (Y, X0, X1) combinations specialized
    // below are supported.
    template <typename Y, typename X0, typename X1>
    __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const;

    // float <- activation(float + float)
    template <>
    __host__ __device__ constexpr void
    operator()<float>(float& y, const float& x0, const float& x1) const
    {
        Activation{}.template operator()<float>(y, x0 + x1);
    };

    // float <- activation(float + half): widen the half operand explicitly
    // before accumulating.
    template <>
    __host__ __device__ constexpr void
    operator()<float>(float& y, const float& x0, const half_t& x1) const
    {
        float x = x0 + type_convert<float>(x1);
        Activation{}.template operator()<float>(y, x);
    };

    // half <- activation(float + float): activate in float, narrow the result.
    template <>
    __host__ __device__ constexpr void
    operator()<half_t>(half_t& y, const float& x0, const float& x1) const
    {
        float result = 0;
        Activation{}.template operator()<float>(result, x0 + x1);
        y = type_convert<half_t>(result);
    };

    // half <- activation(float + half): widen x1 with type_convert, consistent
    // with the (float, half_t) overload above and with element_wise::Add,
    // instead of relying on implicit half->float promotion.
    template <>
    __host__ __device__ constexpr void
    operator()<half_t>(half_t& y, const float& x0, const half_t& x1) const
    {
        float result = 0;
        Activation{}.template operator()<float>(result, x0 + type_convert<float>(x1));
        y = type_convert<half_t>(result);
    };
};

} // namespace impl
} // namespace ck
// clang-format off // clang-format off
template <typename ADataType, typename BDataType, typename DsDataType, typename CDataType> template <typename ADataType, typename BDataType, typename DsDataType, typename CDataType>
using DeviceOpInstance_64_16_16_64 = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3< using DeviceOpInstance_64_16_16_64 = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3<
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
#include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp" #include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp" #include "ck/library/utility/check_err.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
...@@ -38,11 +39,11 @@ using DsLayout = ck::Tuple<D0Layout>; ...@@ -38,11 +39,11 @@ using DsLayout = ck::Tuple<D0Layout>;
using ELayout = Row; using ELayout = Row;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// using Add = ck::tensor_operation::element_wise::Add; using Add = ck::tensor_operation::element_wise::Add;
using AElementOp = PassThrough; using AElementOp = PassThrough;
using BElementOp = PassThrough; using BElementOp = PassThrough;
using CElementOp = PassThrough; using CElementOp = Add;
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<A0DataType, using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<A0DataType,
B0DataType, B0DataType,
...@@ -50,8 +51,88 @@ using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<A0DataTy ...@@ -50,8 +51,88 @@ using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<A0DataTy
AccDataType, AccDataType,
AElementOp, AElementOp,
BElementOp, BElementOp,
CElementOp>; PassThrough>;
// Relative tolerance used when checking device output against the host
// reference; looser bounds for the low-precision formats.
template <typename DataType>
inline __host__ __device__ constexpr double get_rtol()
{
    if constexpr(std::is_same_v<DataType, double>)
    {
        return 1e-6;
    }
    else if constexpr(std::is_same_v<DataType, float> || std::is_same_v<DataType, ck::half_t>)
    {
        return 1e-3;
    }
    else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
    {
        return 5e-2;
    }
    else if constexpr(std::is_same_v<DataType, int32_t> || std::is_same_v<DataType, int8_t>)
    {
        return 1e-1;
    }
    else if constexpr(std::is_same_v<DataType, ck::f8_t>)
    {
        return 1e-1; // 240 and 224 are acceptable
    }
    else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
    {
        return 1.5e-1; // 57344 and 49152 are acceptable
    }
    else
    {
        return 1e-3; // fallback for any other type
    }
}
// Absolute tolerance companion to get_rtol(); the f8/bf8 bounds are sized to
// one quantization step at the top of those formats' ranges.
template <typename DataType>
inline __host__ __device__ constexpr double get_atol()
{
    if constexpr(std::is_same_v<DataType, double>)
    {
        return 1e-6;
    }
    else if constexpr(std::is_same_v<DataType, float> || std::is_same_v<DataType, ck::half_t>)
    {
        return 1e-3;
    }
    else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
    {
        return 5e-2;
    }
    else if constexpr(std::is_same_v<DataType, int32_t> || std::is_same_v<DataType, int8_t>)
    {
        return 1e-1;
    }
    else if constexpr(std::is_same_v<DataType, ck::f8_t>)
    {
        return 16.1; // 240 and 224 are acceptable
    }
    else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
    {
        return 8192.1; // 57344 and 49152 are acceptable
    }
    else
    {
        return 1e-3; // fallback for any other type
    }
}
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
bool do_verification = true; bool do_verification = true;
...@@ -63,11 +144,6 @@ int main(int argc, char* argv[]) ...@@ -63,11 +144,6 @@ int main(int argc, char* argv[])
ck::index_t N = 16; ck::index_t N = 16;
ck::index_t K = 64; ck::index_t K = 64;
ck::index_t StrideA = K;
ck::index_t StrideB = N;
ck::index_t StrideD = 0;
ck::index_t StrideE = N;
if(argc == 1) if(argc == 1)
{ {
// use default case // use default case
...@@ -78,7 +154,7 @@ int main(int argc, char* argv[]) ...@@ -78,7 +154,7 @@ int main(int argc, char* argv[])
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
} }
else if(argc == 11) else if(argc == 7)
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
...@@ -87,21 +163,21 @@ int main(int argc, char* argv[]) ...@@ -87,21 +163,21 @@ int main(int argc, char* argv[])
M = std::stoi(argv[4]); M = std::stoi(argv[4]);
N = std::stoi(argv[5]); N = std::stoi(argv[5]);
K = std::stoi(argv[6]); K = std::stoi(argv[6]);
StrideA = std::stoi(argv[7]);
StrideB = std::stoi(argv[8]);
StrideD = std::stoi(argv[9]);
StrideE = std::stoi(argv[10]);
} }
else else
{ {
printf("arg1: verification (0=no, 1=yes)\n"); printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n"); printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE\n"); printf("arg4 to 6: M (256x), N(128x), K(32x)\n");
exit(0); exit(0);
} }
ck::index_t StrideA = K;
ck::index_t StrideB = N;
ck::index_t StrideD = 0;
ck::index_t StrideE = N;
auto f_host_tensor_descriptor = auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) { [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
using namespace ck::literals; using namespace ck::literals;
...@@ -132,12 +208,12 @@ int main(int argc, char* argv[]) ...@@ -132,12 +208,12 @@ int main(int argc, char* argv[])
case 1: case 1:
a0_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{-0.5, 0.5}); a0_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{-0.5, 0.5});
b0_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5}); b0_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
d0_m_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{0}); d0_m_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{-0.5, 0.5});
break; break;
default: default:
a0_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0}); a0_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
b0_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5}); b0_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
d0_m_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{0}); d0_m_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{-0.5, 0.5});
} }
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize()); DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize());
...@@ -183,13 +259,28 @@ int main(int argc, char* argv[]) ...@@ -183,13 +259,28 @@ int main(int argc, char* argv[])
auto ref_invoker = ref_gemm.MakeInvoker(); auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument( auto ref_argument = ref_gemm.MakeArgument(
a0_m_k, b0_k_n, e_m_n_host_result, AElementOp{}, BElementOp{}, CElementOp{}); a0_m_k, b0_k_n, e_m_n_host_result, AElementOp{}, BElementOp{}, PassThrough{});
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
CElementOp cde_element_op;
for(int m = 0; m < M; ++m)
{
for(int n = 0; n < N; ++n)
{
cde_element_op(e_m_n_host_result(m, n), e_m_n_host_result(m, n), d0_m_n(m, n));
}
}
e_device_buf.FromDevice(e_m_n_device_result.mData.data()); e_device_buf.FromDevice(e_m_n_device_result.mData.data());
return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; return ck::utils::check_err(e_m_n_device_result,
e_m_n_host_result,
"Error: Incorrect results!",
get_rtol<EDataType>(),
get_atol<EDataType>())
? 0
: 1;
} }
return 0; return 0;
......
...@@ -33,7 +33,7 @@ struct Add ...@@ -33,7 +33,7 @@ struct Add
__host__ __device__ constexpr void __host__ __device__ constexpr void
operator()<float>(float& y, const float& x0, const half_t& x1) const operator()<float>(float& y, const float& x0, const half_t& x1) const
{ {
y = x0 + type_convert<half_t>(x1); y = x0 + type_convert<float>(x1);
}; };
template <> template <>
......
...@@ -1077,6 +1077,7 @@ struct ConvScaleRelu ...@@ -1077,6 +1077,7 @@ struct ConvScaleRelu
float scale_out_; float scale_out_;
}; };
// support fastconvert of int8 to fp16 // support fastconvert of int8 to fp16
template <typename InputDataType, typename OutputDataType, index_t RegPackNumber> template <typename InputDataType, typename OutputDataType, index_t RegPackNumber>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment