Unverified commit 0d0150db, authored by zjing14, committed by GitHub
Browse files

bf16A_Int8B with fastgelu/bias (#1264)

* changed the copy function to v7r2

* adding multi_abd

* in-progress

* add post-load oob check

* debugging

* adjust instances

* add run_lds

* add elementwise_op

* replace multi_abd_device with v3

* clean up

* clean

* clean

* Added LDSType

* profiling

* adjust oobcheck

* add missing file

* refactor

* clean

* add examples
parent b4032629
...@@ -40,23 +40,10 @@ inline constexpr bool is_pointer_v = std::is_pointer<T>::value; ...@@ -40,23 +40,10 @@ inline constexpr bool is_pointer_v = std::is_pointer<T>::value;
template <typename Y, typename X, typename enable_if<sizeof(X) == sizeof(Y), bool>::type = false> template <typename Y, typename X, typename enable_if<sizeof(X) == sizeof(Y), bool>::type = false>
__host__ __device__ constexpr Y bit_cast(const X& x) __host__ __device__ constexpr Y bit_cast(const X& x)
{ {
#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_BIT_CAST static_assert(__has_builtin(__builtin_bit_cast), "");
Y y; static_assert(sizeof(X) == sizeof(Y), "Do not support cast between different size of type");
// auto t = reinterpret_cast<const Y*>(&x); return __builtin_bit_cast(Y, x);
// y = *t;
__builtin_memcpy(&y, &x, sizeof(X));
return y;
#else
union AsType
{
X x;
Y y;
};
return AsType{x}.y;
#endif
} }
} // namespace ck } // namespace ck
...@@ -91,23 +91,26 @@ using GK_Tuple = ck::Tuple<G_K>; ...@@ -91,23 +91,26 @@ using GK_Tuple = ck::Tuple<G_K>;
using GK_GK_Tuple = ck::Tuple<G_K, G_K>; using GK_GK_Tuple = ck::Tuple<G_K, G_K>;
// pointwise functor // pointwise functor
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Relu = ck::tensor_operation::element_wise::Relu; using Relu = ck::tensor_operation::element_wise::Relu;
using TanH = ck::tensor_operation::element_wise::TanH; using TanH = ck::tensor_operation::element_wise::TanH;
using Scale = ck::tensor_operation::element_wise::Scale; using Scale = ck::tensor_operation::element_wise::Scale;
using Bilinear = ck::tensor_operation::element_wise::Bilinear; using Bilinear = ck::tensor_operation::element_wise::Bilinear;
using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu; using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu;
using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu; using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
using AddRelu = ck::tensor_operation::element_wise::AddRelu; using MultiplyAddFastGelu = ck::tensor_operation::element_wise::MultiplyAddFastGelu;
using AddSilu = ck::tensor_operation::element_wise::AddSilu; using AddRelu = ck::tensor_operation::element_wise::AddRelu;
using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd; using AddSilu = ck::tensor_operation::element_wise::AddSilu;
using FastGelu = ck::tensor_operation::element_wise::FastGelu; using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd;
using AddMultiply = ck::tensor_operation::element_wise::AddMultiply; using FastGelu = ck::tensor_operation::element_wise::FastGelu;
using MultiplyAdd = ck::tensor_operation::element_wise::MultiplyAdd; using MultiplyFastGelu = ck::tensor_operation::element_wise::MultiplyFastGelu;
using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd; using AddMultiply = ck::tensor_operation::element_wise::AddMultiply;
using Gelu = ck::tensor_operation::element_wise::Gelu; using MultiplyAdd = ck::tensor_operation::element_wise::MultiplyAdd;
using Swish = ck::tensor_operation::element_wise::Swish; using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd;
using Add = ck::tensor_operation::element_wise::Add; using Gelu = ck::tensor_operation::element_wise::Gelu;
using Swish = ck::tensor_operation::element_wise::Swish;
using Add = ck::tensor_operation::element_wise::Add;
using Multiply = ck::tensor_operation::element_wise::Multiply;
template <typename Activation> template <typename Activation>
using Activation_Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp<Activation>; using Activation_Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp<Activation>;
......
...@@ -4,7 +4,8 @@ set(GEMM_MULTI_ABD_INSTANCES) ...@@ -4,7 +4,8 @@ set(GEMM_MULTI_ABD_INSTANCES)
list(APPEND GEMM_MULTI_ABD_INSTANCES list(APPEND GEMM_MULTI_ABD_INSTANCES
device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_mk_nk_mn_v1_instance.cpp device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_mk_nk_mn_v1_instance.cpp
device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_km_kn_mn_v1_instance.cpp
device_gemm_xdl_multi_abd_multiply_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
) )
add_instance_library(device_gemm_multi_abd_instance ${GEMM_MULTI_ABD_INSTANCES}) add_instance_library(device_gemm_multi_abd_instance ${GEMM_MULTI_ABD_INSTANCES})
...@@ -47,14 +47,14 @@ using D0Layout = Row; ...@@ -47,14 +47,14 @@ using D0Layout = Row;
// using DsLayout = ck::Tuple<Row>; // using DsLayout = ck::Tuple<Row>;
using ELayout = Row; using ELayout = Row;
using Scales = ck::tensor_operation::element_wise::Scales; using Multiply = ck::tensor_operation::element_wise::Multiply;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu; using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
using Add = ck::tensor_operation::element_wise::Add; using Add = ck::tensor_operation::element_wise::Add;
using FastGelu = ck::tensor_operation::element_wise::FastGelu; using FastGelu = ck::tensor_operation::element_wise::FastGelu;
using AElementOp = PassThrough; using AElementOp = PassThrough;
using BElementOp = Scales; using BElementOp = Multiply;
// using CDEElementOp = AddFastGelu; // using CDEElementOp = AddFastGelu;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment