"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "164ee0e0c1f22341805a43ee6578826ad58e85eb"
Unverified commit 8ef8a994, authored by valarLip, committed by GitHub

[CK_TILE] add more strides for layernorm to support non-contiguous tensors (#1650)

* [CK_TILE] add more strides for layernorm to support non-contiguous tensors

* align with CK coding style

* extend strides in the layernorm example

* clang-format...
parent 13332998
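The change replaces the single shared `stride` with four independent row strides (`x_stride`, `xr_stride`, `y_stride`, `yr_stride`), one per tensor, so the input, input residual, output, and output residual can each carry their own row padding. As a standalone illustration (plain C++, not CK_TILE code): element (i, j) of a row-major m x n tensor whose rows are padded to `stride` elements lives at linear offset `i * stride + j`.

// Minimal sketch (plain C++, not CK_TILE code) of row-strided indexing.
#include <cassert>
#include <vector>

int main()
{
    const int m = 4, n = 6, x_stride = 8; // each row padded from 6 to 8 floats
    std::vector<float> x(m * x_stride, 0.0f);

    // Element (i, j) of the logical m x n tensor:
    auto at = [&](int i, int j) -> float& { return x[i * x_stride + j]; };

    at(2, n - 1) = 1.0f;   // touches x[2 * 8 + 5] == x[21] ...
    assert(x[21] == 1.0f); // ... not x[2 * 6 + 5] == x[17]
    return 0;
}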
@@ -25,7 +25,10 @@ auto create_args(int argc, char* argv[])
     ck_tile::ArgParser arg_parser;
     arg_parser.insert("m", "3328", "m dimension")
         .insert("n", "4096", "n dimension")
-        .insert("stride", "-1", "stride per row, if -1 then equal to n")
+        .insert("x_stride", "-1", "x row_stride, if -1 then equal to n")
+        .insert("xr_stride", "-1", "x residual row_stride, if -1 then equal to n")
+        .insert("y_stride", "-1", "y row_stride, if -1 then equal to n")
+        .insert("yr_stride", "-1", "y residual row_stride, if -1 then equal to n")
         .insert("e", "1e-5", "epsilon")
         .insert("save_mv", "0", "save mean/variance(invstd) or not. set to 1 in training case")
         .insert("v", "1", "cpu validation or not")
@@ -54,11 +57,20 @@ template <typename InDataType,
           bool SaveMeanVar>
 bool run(const ck_tile::ArgParser& arg_parser)
 {
     ck_tile::index_t m = arg_parser.get_int("m");
     ck_tile::index_t n = arg_parser.get_int("n");
-    ck_tile::index_t stride = arg_parser.get_int("stride");
-    if(stride < 0)
-        stride = n;
+    ck_tile::index_t x_stride = arg_parser.get_int("x_stride");
+    if(x_stride < 0)
+        x_stride = n;
+    ck_tile::index_t xr_stride = arg_parser.get_int("xr_stride");
+    if(xr_stride < 0)
+        xr_stride = n;
+    ck_tile::index_t y_stride = arg_parser.get_int("y_stride");
+    if(y_stride < 0)
+        y_stride = n;
+    ck_tile::index_t yr_stride = arg_parser.get_int("yr_stride");
+    if(yr_stride < 0)
+        yr_stride = n;
     float epsilon = arg_parser.get_float("e");
     std::string prec_i = arg_parser.get_str("prec_i");
     std::string prec_o = arg_parser.get_str("prec_o");
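The four "read stride, default to n" blocks above are deliberately uniform. If one wanted to condense them, a small helper would do; a sketch using only `ArgParser::get_int` as it appears in this diff (the helper itself is hypothetical, not part of the change):

// Hypothetical condensation of the four stride reads above.
auto stride_or = [&](const char* name, ck_tile::index_t dflt) {
    ck_tile::index_t s = arg_parser.get_int(name);
    return s < 0 ? dflt : s; // any negative value falls back to the default
};
ck_tile::index_t x_stride  = stride_or("x_stride", n);
ck_tile::index_t xr_stride = stride_or("xr_stride", n);
ck_tile::index_t y_stride  = stride_or("y_stride", n);
ck_tile::index_t yr_stride = stride_or("yr_stride", n);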
@@ -89,7 +101,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
         return false;
     }

-    assert(stride >= n);
+    assert(x_stride >= n);

     using TypeConfig = LayerNormTypeConfig<InDataType, OutDataType, XScaleDataType, YScaleDataType>;
@@ -108,15 +120,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
     using ComputeDataType = typename TypeConfig::ComputeDataType;

     // host verify
-    ck_tile::HostTensor<XDataType> x_host({m, n}, {stride, 1});
+    ck_tile::HostTensor<XDataType> x_host({m, n}, {x_stride, 1});
     ck_tile::HostTensor<GammaDataType> gamma_host({n});
     ck_tile::HostTensor<BetaDataType> beta_host({n});
-    ck_tile::HostTensor<XResidualDataType> x_residual_host({m, n}, {stride, 1});
-    ck_tile::HostTensor<YResidualDataType> y_residual_host({m, n}, {stride, 1});
-    ck_tile::HostTensor<YDataType> y_host_ref({m, n}, {stride, 1});
-    ck_tile::HostTensor<YDataType> y_host_dev({m, n}, {stride, 1});
+    ck_tile::HostTensor<XResidualDataType> x_residual_host({m, n}, {xr_stride, 1});
+    ck_tile::HostTensor<YResidualDataType> y_residual_host({m, n}, {yr_stride, 1});
+    ck_tile::HostTensor<YDataType> y_host_ref({m, n}, {y_stride, 1});
+    ck_tile::HostTensor<YDataType> y_host_dev({m, n}, {y_stride, 1});
     ck_tile::HostTensor<MeanDataType> mean_host_ref({m});
     ck_tile::HostTensor<InvStdDataType> invStd_host_ref({m});
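`HostTensor` takes per-dimension lengths and strides, so `{m, n}` with `{x_stride, 1}` describes a row-major tensor whose rows start `x_stride` elements apart. Because each tensor now carries its own stride, mixed layouts can be validated side by side; a sketch assuming the same types and includes as the example file above:

// Sketch: mixed layouts, using only the HostTensor constructor shown above.
// x is read from rows padded by 32 extra elements; y is written contiguously.
ck_tile::HostTensor<XDataType> x_padded({m, n}, {n + 32, 1}); // non-contiguous rows
ck_tile::HostTensor<YDataType> y_packed({m, n}, {n, 1});      // contiguous rows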
@@ -162,7 +174,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
     }();

     std::cout << "[" << prec_str << "]"
-              << " m:" << m << ", n:" << n << ", stride:" << stride << std::flush;
+              << " m:" << m << ", n:" << n << ", x_stride:" << x_stride
+              << ", xr_stride:" << xr_stride << ", y_stride:" << y_stride
+              << ", yr_stride:" << yr_stride << std::flush;

     layernorm2d_fwd_traits traits{
         prec_i, prec_o, prec_sx, prec_sy, SaveMeanVar, fused_add, fused_quant};
@@ -182,7 +196,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
         epsilon,
         m,
         n,
-        stride};
+        x_stride,   // x row stride
+        xr_stride,  // x residual row stride
+        y_stride,   // y row stride
+        yr_stride}; // y residual row stride

     float ave_time = layernorm2d_fwd(
         traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
@@ -285,7 +302,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
     y_buf.FromDevice(y_host_dev.data());

-    ck_tile::HostTensor<YResidualDataType> y_residual_host_dev({m, n}, {stride, 1});
+    ck_tile::HostTensor<YResidualDataType> y_residual_host_dev({m, n}, {yr_stride, 1});
     if(fused_add == 1)
     {
         y_residual_buf.FromDevice(y_residual_host_dev.data());
@@ -293,7 +310,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
     auto [rtol, atol] = get_elimit<InDataType>();

-    if(stride == n)
+    if(x_stride == n)
     {
         pass = ck_tile::check_err(
             y_host_dev, y_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol);
@@ -310,10 +327,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
     {
         for(int i_r = 0; i_r < m; i_r++)
         {
-            std::vector<YDataType> y_host_dev_row(y_host_dev.begin() + i_r * stride,
-                                                  y_host_dev.begin() + i_r * stride + n);
-            std::vector<YDataType> y_host_ref_row(y_host_ref.begin() + i_r * stride,
-                                                  y_host_ref.begin() + i_r * stride + n);
+            std::vector<YDataType> y_host_dev_row(y_host_dev.begin() + i_r * y_stride,
+                                                  y_host_dev.begin() + i_r * y_stride + n);
+            std::vector<YDataType> y_host_ref_row(y_host_ref.begin() + i_r * y_stride,
+                                                  y_host_ref.begin() + i_r * y_stride + n);
             pass &= ck_tile::check_err(y_host_dev_row,
                                        y_host_ref_row,
                                        std::string("OUT[") + std::to_string(i_r) +
@@ -323,10 +340,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
             if(fused_add == 1)
             {
                 std::vector<YResidualDataType> y_residual_host_dev_row(
-                    y_residual_host_dev.begin() + i_r * stride,
-                    y_residual_host_dev.begin() + i_r * stride + n);
+                    y_residual_host_dev.begin() + i_r * yr_stride,
+                    y_residual_host_dev.begin() + i_r * yr_stride + n);
                 std::vector<YResidualDataType> y_residual_host_ref_row(
-                    x_host.begin() + i_r * stride, x_host.begin() + i_r * stride + n);
+                    x_host.begin() + i_r * yr_stride, x_host.begin() + i_r * yr_stride + n);
                 pass &= ck_tile::check_err(y_residual_host_dev_row,
                                            y_residual_host_ref_row,
                                            std::string("ADD[") + std::to_string(i_r) +
......
@@ -28,7 +28,10 @@ struct Layernorm2dFwdHostArgs
     index_t m;
     index_t n;
-    index_t stride; // row_stride
+    index_t x_stride;  // x row stride
+    index_t xr_stride; // x residual row stride
+    index_t y_stride;  // y row stride
+    index_t yr_stride; // y residual row stride
 };

 // TODO: Extract some type to wrapper class
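Setting all four strides equal to n reproduces the old single-stride, fully contiguous behavior, so existing callers only need a mechanical update. A sketch of that backward-compatible setup (field names from the struct above; pointer and epsilon members, not shown in this diff, are omitted):

// Backward-compatible setup (sketch): all strides equal to n is exactly
// the old contiguous single-stride behavior.
Layernorm2dFwdHostArgs hargs{};
hargs.m         = m;
hargs.n         = n;
hargs.x_stride  = n;
hargs.xr_stride = n;
hargs.y_stride  = n;
hargs.yr_stride = n;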
@@ -93,7 +96,10 @@ struct Layernorm2dFwd
         index_t m;
         index_t n;
-        index_t stride; // row_stride
+        index_t x_stride;  // x row stride
+        index_t xr_stride; // x residual row stride
+        index_t y_stride;  // y row stride
+        index_t yr_stride; // y residual row stride
     };

     using Hargs = Layernorm2dFwdHostArgs;
@@ -112,7 +118,10 @@ struct Layernorm2dFwd
                 hargs.epsilon,
                 hargs.m,
                 hargs.n,
-                hargs.stride};
+                hargs.x_stride,
+                hargs.xr_stride,
+                hargs.y_stride,
+                hargs.yr_stride};
     }

     CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs)
@@ -182,7 +191,7 @@ struct Layernorm2dFwd
             const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
                 static_cast<const XDataType*>(kargs.p_x),
                 make_tuple(kargs.m, kargs.n),
-                make_tuple(kargs.stride, 1),
+                make_tuple(kargs.x_stride, 1),
                 number<Vector_N>{},
                 number<1>{});
@@ -201,7 +210,7 @@ struct Layernorm2dFwd
             const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
                 static_cast<const XResidualDataType*>(kargs.p_x_residual),
                 make_tuple(kargs.m, kargs.n),
-                make_tuple(kargs.stride, 1),
+                make_tuple(kargs.xr_stride, 1),
                 number<Vector_N>{},
                 number<1>{});
@@ -250,7 +259,7 @@ struct Layernorm2dFwd
             auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
                 static_cast<YDataType*>(kargs.p_y),
                 make_tuple(kargs.m, kargs.n),
-                make_tuple(kargs.stride, 1),
+                make_tuple(kargs.y_stride, 1),
                 number<Vector_N>{},
                 number<1>{});
@@ -266,7 +275,7 @@ struct Layernorm2dFwd
            auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
                static_cast<YResidualDataType*>(kargs.p_y_residual),
                make_tuple(kargs.m, kargs.n),
-               make_tuple(kargs.stride, 1),
+               make_tuple(kargs.yr_stride, 1),
                number<Vector_N>{},
                number<1>{});
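The four global views above are constructed identically, differing only in element type and which row stride they consume. A hypothetical helper (not part of this PR) could express that symmetry, using only calls already visible in this diff; it assumes placement inside Layernorm2dFwd where index_t, Vector_N, and the CK_TILE_DEVICE macro are in scope:

// Hypothetical helper: one row-major global view per (pointer, row stride) pair.
template <typename DataType>
CK_TILE_DEVICE static auto
make_row_major_view(DataType* p, index_t m, index_t n, index_t row_stride)
{
    return make_naive_tensor_view<address_space_enum::global>(
        p, make_tuple(m, n), make_tuple(row_stride, 1), number<Vector_N>{}, number<1>{});
}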
......