Unverified commit 550248de, authored by carlushuang and committed by GitHub
Browse files

[layernorm] hot fix (#1620)

* hot fix ln

* some rename
parent c3a4800c
...@@ -127,9 +127,10 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -127,9 +127,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::HostTensor<XScaleDataType> x_scale_host_dev({n}); ck_tile::HostTensor<XScaleDataType> x_scale_host_dev({n});
ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host); ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
ck_tile::FillUniformDistribution<XResidualDataType>{-.5f, .5f}(x_residual_host);
ck_tile::FillUniformDistribution<XScaleDataType>{-1.f, 1.f}(x_scale_host);
ck_tile::FillUniformDistribution<GammaDataType>{-.5f, .5f}(gamma_host); ck_tile::FillUniformDistribution<GammaDataType>{-.5f, .5f}(gamma_host);
ck_tile::FillUniformDistribution<BetaDataType>{-.5f, .5f}(beta_host); ck_tile::FillUniformDistribution<BetaDataType>{-.5f, .5f}(beta_host);
ck_tile::FillUniformDistribution<XScaleDataType>{-1.f, 1.f}(x_scale_host);
ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes());
...@@ -212,7 +213,11 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -212,7 +213,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
x_host.mData.cend(), x_host.mData.cend(),
x_residual_host.mData.cbegin(), x_residual_host.mData.cbegin(),
x_host.mData.begin(), x_host.mData.begin(),
std::plus<XDataType>{}); [](auto x_, auto r_) {
auto o_ = ck_tile::type_convert<ComputeDataType>(x_) +
ck_tile::type_convert<ComputeDataType>(r_);
return ck_tile::type_convert<XDataType>(o_);
});
} }
ck_tile::reference_layernorm2d_fwd<XDataType, ck_tile::reference_layernorm2d_fwd<XDataType,
GammaDataType, GammaDataType,
...@@ -280,10 +285,10 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -280,10 +285,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
y_buf.FromDevice(y_host_dev.data()); y_buf.FromDevice(y_host_dev.data());
ck_tile::HostTensor<YResidualDataType> sy_host_dev({m, n}, {stride, 1}); ck_tile::HostTensor<YResidualDataType> y_residual_host_dev({m, n}, {stride, 1});
if(fused_add == 1) if(fused_add == 1)
{ {
y_residual_buf.FromDevice(sy_host_dev.data()); y_residual_buf.FromDevice(y_residual_host_dev.data());
} }
auto [rtol, atol] = get_elimit<InDataType>(); auto [rtol, atol] = get_elimit<InDataType>();
...@@ -294,8 +299,11 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -294,8 +299,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
y_host_dev, y_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol); y_host_dev, y_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol);
if(fused_add == 1) if(fused_add == 1)
{ {
pass &= ck_tile::check_err( pass &= ck_tile::check_err(y_residual_host_dev,
sy_host_dev, x_host, std::string("ADD Error: Incorrect results!"), rtol, atol); x_host,
std::string("ADD Error: Incorrect results!"),
rtol,
atol);
} }
} }
else else
...@@ -314,12 +322,13 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -314,12 +322,13 @@ bool run(const ck_tile::ArgParser& arg_parser)
atol); atol);
if(fused_add == 1) if(fused_add == 1)
{ {
std::vector<YResidualDataType> sy_host_dev_row( std::vector<YResidualDataType> y_residual_host_dev_row(
sy_host_dev.begin() + i_r * stride, sy_host_dev.begin() + i_r * stride + n); y_residual_host_dev.begin() + i_r * stride,
std::vector<YResidualDataType> sy_host_ref_row( y_residual_host_dev.begin() + i_r * stride + n);
std::vector<YResidualDataType> y_residual_host_ref_row(
x_host.begin() + i_r * stride, x_host.begin() + i_r * stride + n); x_host.begin() + i_r * stride, x_host.begin() + i_r * stride + n);
pass &= ck_tile::check_err(sy_host_dev_row, pass &= ck_tile::check_err(y_residual_host_dev_row,
sy_host_ref_row, y_residual_host_ref_row,
std::string("ADD[") + std::to_string(i_r) + std::string("ADD[") + std::to_string(i_r) +
std::string("] Error: Incorrect results!"), std::string("] Error: Incorrect results!"),
rtol, rtol,
......
...@@ -111,8 +111,9 @@ struct Layernorm2dFwdPipelineOnePass ...@@ -111,8 +111,9 @@ struct Layernorm2dFwdPipelineOnePass
{ {
sweep_tile(x_resi, [&](auto idx) { sweep_tile(x_resi, [&](auto idx) {
// compute x = x_resi + x // compute x = x_resi + x
x(idx) = type_convert<YResidualDataType>(x_resi(idx)) + auto re_ = type_convert<ComputeDataType>(x_resi(idx)) +
type_convert<YResidualDataType>(x(idx)); type_convert<ComputeDataType>(x(idx));
x(idx) = type_convert<XDataType>(re_);
}); });
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE) if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE)
store_tile(y_residual_window, x); store_tile(y_residual_window, x);
......
...@@ -122,8 +122,9 @@ struct Layernorm2dFwdPipelineTwoPass ...@@ -122,8 +122,9 @@ struct Layernorm2dFwdPipelineTwoPass
{ {
sweep_tile(x_resi, [&](auto idx) { sweep_tile(x_resi, [&](auto idx) {
// compute x = x_resi + x // compute x = x_resi + x
x(idx) = type_convert<YResidualDataType>(x_resi(idx)) + auto re_ = type_convert<ComputeDataType>(x_resi(idx)) +
type_convert<YResidualDataType>(x(idx)); type_convert<ComputeDataType>(x(idx));
x(idx) = type_convert<XDataType>(re_);
}); });
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE) if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE)
{ {
...@@ -170,8 +171,9 @@ struct Layernorm2dFwdPipelineTwoPass ...@@ -170,8 +171,9 @@ struct Layernorm2dFwdPipelineTwoPass
{ {
sweep_tile(x_resi, [&](auto idx) { sweep_tile(x_resi, [&](auto idx) {
// compute x = x_resi + x // compute x = x_resi + x
x(idx) = type_convert<YResidualDataType>(x_resi(idx)) + auto re_ = type_convert<ComputeDataType>(x_resi(idx)) +
type_convert<YResidualDataType>(x(idx)); type_convert<ComputeDataType>(x(idx));
x(idx) = type_convert<XDataType>(re_);
}); });
} }
// load gamma/beta (TODO: support no gamma/beta?) // load gamma/beta (TODO: support no gamma/beta?)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment