Commit 8a60a329 authored by Chao Liu
Browse files

add gemm bias add fastgelu

parent c7d59414
...@@ -25,7 +25,6 @@ using Row = ck::tensor_layout::gemm::RowMajor; ...@@ -25,7 +25,6 @@ using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor; using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using FastGelu = ck::tensor_operation::element_wise::FastGelu;
struct AddAddFastGelu struct AddAddFastGelu
{ {
...@@ -34,24 +33,19 @@ struct AddAddFastGelu ...@@ -34,24 +33,19 @@ struct AddAddFastGelu
__host__ __device__ void __host__ __device__ void
operator()(ck::half_t& y, const float& x0, const ck::half_t& x1, const ck::half_t& x2) const operator()(ck::half_t& y, const float& x0, const ck::half_t& x1, const ck::half_t& x2) const
{ {
#if 0
const float x = x0 + x1 + x2; const float x = x0 + x1 + x2;
const float u = float(2) * x * (float(0.035677) * x * x + float(0.797885)); const float u = float(2) * x * (float(0.035677) * x * x + float(0.797885));
const float emu = exp(-u); const float emu = exp(-u);
const float cdf = float(0.5) + float(0.5) * (float(2) / (float(1) + emu) - float(1)); const float cdf = float(0.5) + float(0.5) * (float(2) / (float(1) + emu) - float(1));
y = x * cdf; y = ck::type_convert<ck::half_t>(x * cdf);
#else
const float x = x0 + x2;
y = x;
#endif
} }
}; };
using ADataType = F16; using ADataType = F16;
using BDataType = F16; using BDataType = F16;
using AccDataType = F32; using AccDataType = F32;
using CShuffleDataType = F32;
using D0DataType = F16; using D0DataType = F16;
using D1DataType = F16; using D1DataType = F16;
using DsDataType = ck::Tuple<D0DataType, D1DataType>; using DsDataType = ck::Tuple<D0DataType, D1DataType>;
...@@ -60,25 +54,20 @@ using EDataType = F16; ...@@ -60,25 +54,20 @@ using EDataType = F16;
using ALayout = Row; using ALayout = Row;
using BLayout = Col; using BLayout = Col;
using ELayout = Row; using ELayout = Row;
using AccDataType = F32;
using AElementOp = PassThrough; using AElementOp = PassThrough;
using BElementOp = PassThrough; using BElementOp = PassThrough;
#if 0
using CDEElementOp = FastGelu;
#else
using CDEElementOp = AddAddFastGelu; using CDEElementOp = AddAddFastGelu;
#endif
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off // clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle
//######| ALayout| BLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //######| ALayout| BLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //######| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //######| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< Row, Col, Row, F16, F16, F32, F32, DsDataType, F16, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; < ALayout, BLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
// clang-format on // clang-format on
int main(int argc, char* argv[]) int main(int argc, char* argv[])
...@@ -160,28 +149,16 @@ int main(int argc, char* argv[]) ...@@ -160,28 +149,16 @@ int main(int argc, char* argv[])
{ {
case 0: break; case 0: break;
case 1: case 1:
#if 0
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5}); a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5}); b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
d0_m_n.GenerateTensorValue(GeneratorTensor_2<EDataType>{-5, 5}); d0_m_n.GenerateTensorValue(GeneratorTensor_2<EDataType>{-5, 5});
d1_m_n.GenerateTensorValue(GeneratorTensor_2<EDataType>{-5, 5}); d1_m_n.GenerateTensorValue(GeneratorTensor_2<EDataType>{-5, 5});
#else
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
d0_m_n.GenerateTensorValue(GeneratorTensor_1<EDataType>{1});
d1_m_n.GenerateTensorValue(GeneratorTensor_1<EDataType>{1});
#endif
break; break;
default: default:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}); a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}); b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
#if 0
d0_m_n.GenerateTensorValue(GeneratorTensor_3<EDataType>{0.0, 1.0}); d0_m_n.GenerateTensorValue(GeneratorTensor_3<EDataType>{0.0, 1.0});
d1_m_n.GenerateTensorValue(GeneratorTensor_3<EDataType>{0.0, 1.0}); d1_m_n.GenerateTensorValue(GeneratorTensor_3<EDataType>{0.0, 1.0});
#else
d0_m_n.GenerateTensorValue(GeneratorTensor_1<EDataType>{1});
d1_m_n.GenerateTensorValue(GeneratorTensor_1<EDataType>{1});
#endif
} }
std::cout << "a: " << a_m_k.mDesc.GetElementSpace() << std::endl; std::cout << "a: " << a_m_k.mDesc.GetElementSpace() << std::endl;
...@@ -192,16 +169,14 @@ int main(int argc, char* argv[]) ...@@ -192,16 +169,14 @@ int main(int argc, char* argv[])
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
#if 1 DeviceMem d0_m_n_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpace());
DeviceMem d0_m_n_device_buf(sizeof(EDataType) * d0_m_n.mDesc.GetElementSpace()); DeviceMem d1_m_n_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpace());
#else
DeviceMem d0_m_n_device_buf(sizeof(EDataType) * d1_m_n.mDesc.GetElementSpace());
#endif
DeviceMem d1_m_n_device_buf(sizeof(EDataType) * d1_m_n.mDesc.GetElementSpace());
DeviceMem e_m_n_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpace()); DeviceMem e_m_n_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpace());
a_m_k_device_buf.ToDevice(a_m_k.mData.data()); a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n.mData.data()); b_k_n_device_buf.ToDevice(b_k_n.mData.data());
d0_m_n_device_buf.ToDevice(d0_m_n.mData.data());
d1_m_n_device_buf.ToDevice(d1_m_n.mData.data());
auto a_element_op = AElementOp{}; auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{}; auto b_element_op = BElementOp{};
...@@ -237,8 +212,9 @@ int main(int argc, char* argv[]) ...@@ -237,8 +212,9 @@ int main(int argc, char* argv[])
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * M * N * K; std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype = std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; sizeof(D0DataType) * N + sizeof(D1DataType) * M * N +
sizeof(EDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time; float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
...@@ -247,11 +223,10 @@ int main(int argc, char* argv[]) ...@@ -247,11 +223,10 @@ int main(int argc, char* argv[])
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl; << gemm.GetTypeString() << std::endl;
e_m_n_device_buf.FromDevice(e_m_n_device_result.mData.data());
if(do_verification) if(do_verification)
{ {
#if 1 e_m_n_device_buf.FromDevice(e_m_n_device_result.mData.data());
Tensor<AccDataType> c_m_n(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); Tensor<AccDataType> c_m_n(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType, using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
...@@ -276,23 +251,6 @@ int main(int argc, char* argv[]) ...@@ -276,23 +251,6 @@ int main(int argc, char* argv[])
cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_m_n(m, n), d1_m_n(m, n)); cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_m_n(m, n), d1_m_n(m, n));
} }
} }
#else
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
EDataType,
AccDataType,
AElementOp,
BElementOp,
CDEElementOp>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a_m_k, b_k_n, e_m_n_host_result, a_element_op, b_element_op, CDEElementOp{});
ref_invoker.Run(ref_argument);
#endif
return ck::utils::check_err(e_m_n_device_result.mData, e_m_n_host_result.mData) ? 0 : 1; return ck::utils::check_err(e_m_n_device_result.mData, e_m_n_host_result.mData) ? 0 : 1;
} }
......
...@@ -146,6 +146,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<DsDataType: ...@@ -146,6 +146,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<DsDataType:
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA) static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
{ {
...@@ -575,12 +576,20 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<DsDataType: ...@@ -575,12 +576,20 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<DsDataType:
<< arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_[I0].GetLength(I0) << arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_[I0].GetLength(I0)
<< ", " << ", "
<< arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_[I0].GetLength(I1) << arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_[I0].GetLength(I1)
<< ", "
<< arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_[I0].GetLength(I2)
<< ", "
<< arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_[I0].GetLength(I3)
<< "}" << std::endl; << "}" << std::endl;
std::cout << "arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_{ " std::cout << "arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_{ "
<< arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_[I1].GetLength(I0) << arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_[I1].GetLength(I0)
<< ", " << ", "
<< arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_[I1].GetLength(I1) << arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_[I1].GetLength(I1)
<< ", "
<< arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_[I1].GetLength(I2)
<< ", "
<< arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_[I1].GetLength(I3)
<< "}" << std::endl; << "}" << std::endl;
std::cout << "p_ds_grid{ " << arg.p_ds_grid_[I0] << ", " << arg.p_ds_grid_[I1] std::cout << "p_ds_grid{ " << arg.p_ds_grid_[I0] << ", " << arg.p_ds_grid_[I1]
......
...@@ -660,14 +660,13 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle ...@@ -660,14 +660,13 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
if constexpr(access_id < num_access - 1) if constexpr(access_id < num_access - 1)
{ {
constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
#if 1
// move on Ds // move on Ds
c_shuffle_block_copy_lds_to_global.MoveSrc1SliceWindow( c_shuffle_block_copy_lds_to_global.MoveSrc1SliceWindow(
ds_grid_desc_mblock_mperblock_nblock_nperblock[I0], c_global_step); ds_grid_desc_mblock_mperblock_nblock_nperblock[I0], c_global_step);
c_shuffle_block_copy_lds_to_global.MoveSrc2SliceWindow( c_shuffle_block_copy_lds_to_global.MoveSrc2SliceWindow(
ds_grid_desc_mblock_mperblock_nblock_nperblock[I1], c_global_step); ds_grid_desc_mblock_mperblock_nblock_nperblock[I1], c_global_step);
#endif
// move on C // move on C
c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment