Unverified commit 550248de, authored by carlushuang, committed by GitHub
Browse files

[layernorm] hot fix (#1620)

* hot fix ln

* some rename
parent c3a4800c
......@@ -127,9 +127,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::HostTensor<XScaleDataType> x_scale_host_dev({n});
ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
ck_tile::FillUniformDistribution<XResidualDataType>{-.5f, .5f}(x_residual_host);
ck_tile::FillUniformDistribution<XScaleDataType>{-1.f, 1.f}(x_scale_host);
ck_tile::FillUniformDistribution<GammaDataType>{-.5f, .5f}(gamma_host);
ck_tile::FillUniformDistribution<BetaDataType>{-.5f, .5f}(beta_host);
ck_tile::FillUniformDistribution<XScaleDataType>{-1.f, 1.f}(x_scale_host);
ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes());
......@@ -212,7 +213,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
x_host.mData.cend(),
x_residual_host.mData.cbegin(),
x_host.mData.begin(),
std::plus<XDataType>{});
[](auto x_, auto r_) {
auto o_ = ck_tile::type_convert<ComputeDataType>(x_) +
ck_tile::type_convert<ComputeDataType>(r_);
return ck_tile::type_convert<XDataType>(o_);
});
}
ck_tile::reference_layernorm2d_fwd<XDataType,
GammaDataType,
......@@ -280,10 +285,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
y_buf.FromDevice(y_host_dev.data());
ck_tile::HostTensor<YResidualDataType> sy_host_dev({m, n}, {stride, 1});
ck_tile::HostTensor<YResidualDataType> y_residual_host_dev({m, n}, {stride, 1});
if(fused_add == 1)
{
y_residual_buf.FromDevice(sy_host_dev.data());
y_residual_buf.FromDevice(y_residual_host_dev.data());
}
auto [rtol, atol] = get_elimit<InDataType>();
......@@ -294,8 +299,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
y_host_dev, y_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol);
if(fused_add == 1)
{
pass &= ck_tile::check_err(
sy_host_dev, x_host, std::string("ADD Error: Incorrect results!"), rtol, atol);
pass &= ck_tile::check_err(y_residual_host_dev,
x_host,
std::string("ADD Error: Incorrect results!"),
rtol,
atol);
}
}
else
......@@ -314,12 +322,13 @@ bool run(const ck_tile::ArgParser& arg_parser)
atol);
if(fused_add == 1)
{
std::vector<YResidualDataType> sy_host_dev_row(
sy_host_dev.begin() + i_r * stride, sy_host_dev.begin() + i_r * stride + n);
std::vector<YResidualDataType> sy_host_ref_row(
std::vector<YResidualDataType> y_residual_host_dev_row(
y_residual_host_dev.begin() + i_r * stride,
y_residual_host_dev.begin() + i_r * stride + n);
std::vector<YResidualDataType> y_residual_host_ref_row(
x_host.begin() + i_r * stride, x_host.begin() + i_r * stride + n);
pass &= ck_tile::check_err(sy_host_dev_row,
sy_host_ref_row,
pass &= ck_tile::check_err(y_residual_host_dev_row,
y_residual_host_ref_row,
std::string("ADD[") + std::to_string(i_r) +
std::string("] Error: Incorrect results!"),
rtol,
......
......@@ -111,8 +111,9 @@ struct Layernorm2dFwdPipelineOnePass
{
sweep_tile(x_resi, [&](auto idx) {
// compute x = x_resi + x
x(idx) = type_convert<YResidualDataType>(x_resi(idx)) +
type_convert<YResidualDataType>(x(idx));
auto re_ = type_convert<ComputeDataType>(x_resi(idx)) +
type_convert<ComputeDataType>(x(idx));
x(idx) = type_convert<XDataType>(re_);
});
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE)
store_tile(y_residual_window, x);
......
......@@ -122,8 +122,9 @@ struct Layernorm2dFwdPipelineTwoPass
{
sweep_tile(x_resi, [&](auto idx) {
// compute x = x_resi + x
x(idx) = type_convert<YResidualDataType>(x_resi(idx)) +
type_convert<YResidualDataType>(x(idx));
auto re_ = type_convert<ComputeDataType>(x_resi(idx)) +
type_convert<ComputeDataType>(x(idx));
x(idx) = type_convert<XDataType>(re_);
});
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE)
{
......@@ -170,8 +171,9 @@ struct Layernorm2dFwdPipelineTwoPass
{
sweep_tile(x_resi, [&](auto idx) {
// compute x = x_resi + x
x(idx) = type_convert<YResidualDataType>(x_resi(idx)) +
type_convert<YResidualDataType>(x(idx));
auto re_ = type_convert<ComputeDataType>(x_resi(idx)) +
type_convert<ComputeDataType>(x(idx));
x(idx) = type_convert<XDataType>(re_);
});
}
// load gamma/beta (TODO: support no gamma/beta?)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment