Commit 7fd5e9f5 authored by Chao Liu
Browse files

refactor

parent e09f6e02
......@@ -26,19 +26,26 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// E = FastGelu((A * B) + D0 + D1)
// C = A * B
// Element-wise epilogue functor: adds the two auxiliary tensors D0 and D1 to
// the GEMM accumulator C and applies the fast (tanh-approximated) GeLU,
// writing the result to E:
//
//     E = FastGelu(C + D0 + D1)
struct AddAddFastGelu
{
    // e  [out] result, converted to half precision
    // c  [in]  accumulator value, passed as float
    // d0 [in]  first addend, half precision (promoted to float before the add)
    // d1 [in]  second addend, half precision (promoted to float before the add)
    __host__ __device__ void
    operator()(ck::half_t& e, const float& c, const ck::half_t& d0, const ck::half_t& d1) const
    {
        // Fast GeLU
        // https://paperswithcode.com/method/gelu
        // y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
        //
        // tanh is evaluated through a single exp using the identity
        //     tanh(z) = 2 / (1 + exp(-2z)) - 1
        // so u below is 2z, with the constants pre-multiplied:
        //     0.797885 ~= sqrt(2/pi)
        //     0.035677 ~= 0.044715 * sqrt(2/pi)
        const auto fast_gelu = [&](float x) {
            const float u   = float(2) * x * (float(0.035677) * x * x + float(0.797885));
            const float emu = exp(-u);
            const float cdf = float(0.5) + float(0.5) * (float(2) / (float(1) + emu) - float(1));
            return x * cdf;
        };

        // All arithmetic is done in float; only the final store narrows to half.
        const float y = fast_gelu(c + float(d0) + float(d1));

        e = ck::type_convert<ck::half_t>(y);
    }
};
......
......@@ -567,8 +567,8 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CDEShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
Sequence<true, false, false>, // bool ThreadTransferSrcResetCoordinateAfterRunFlags
Sequence<false>> // bool ThreadTransferDstResetCoordinateAfterRunFlags
Sequence<true, false, false>, // ThreadTransferSrcResetCoordinateAfterRunFlags
Sequence<false>> // ThreadTransferDstResetCoordinateAfterRunFlags
{c_ds_descs,
make_tuple(make_multi_index(0, 0, 0, 0),
make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
......@@ -626,17 +626,20 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
if constexpr(access_id < num_access - 1)
{
constexpr auto e_global_step = sfc_cde_block.GetForwardStep(access_id);
constexpr auto cde_lds_and_global_step =
sfc_cde_block.GetForwardStep(access_id);
// move on Ds
static_for<0, DsDataType::Size(), 1>{}([&](auto i) {
cde_block_copy_lds_and_global.MoveSrcSliceWindow(
c_ds_descs, i + I1, e_global_step);
c_ds_descs, i + I1, cde_lds_and_global_step);
});
// move on E
cde_block_copy_lds_and_global.MoveDstSliceWindow(
tie(e_grid_desc_mblock_mperblock_nblock_nperblock), I0, e_global_step);
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
I0,
cde_lds_and_global_step);
}
});
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment