"vscode:/vscode.git/clone" did not exist on "42dd5af51ec3f345018b2206a1656bb09718af67"
Commit 03247367 authored by rocking's avatar rocking
Browse files

Refine arg of operator()

parent 5c736bc1
......@@ -34,18 +34,9 @@ float layernorm2d_fwd_(const S& s, A a)
constexpr dim3 blocks = Kernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = 1;
return ck_tile::launch_kernel(s,
ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{},
grids,
blocks,
0,
a.p_x,
a.p_gamma,
a.p_beta,
a.p_y,
a.p_mean,
a.p_invStd,
a.epsilon,
a.M,
a.N));
auto kargs = Kernel::MakeKargs(
a.p_x, a.p_gamma, a.p_beta, a.p_y, a.p_mean, a.p_invStd, a.epsilon, a.M, a.N);
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}
......@@ -314,21 +314,13 @@ struct Layernorm2dFwd
}
}
CK_TILE_DEVICE void operator()(const void* p_x,
const void* p_gamma,
const void* p_beta,
void* p_y,
void* p_mean,
void* p_invStd,
const ComputeDataType epsilon,
ck_tile::index_t M,
ck_tile::index_t N) const
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
const auto x_m_n = [&]() {
const auto x_dram_naive = make_naive_tensor_view<address_space_enum::global>(
static_cast<const XDataType*>(p_x),
make_tuple(M, N),
make_tuple(N, 1),
static_cast<const XDataType*>(kargs.p_x),
make_tuple(kargs.M, kargs.N),
make_tuple(kargs.N, 1),
number<kNPerThread>{},
number<1>{});
......@@ -339,8 +331,8 @@ struct Layernorm2dFwd
const auto gamma_n = [&]() {
const auto gamma_dram_naive = make_naive_tensor_view<address_space_enum::global>(
static_cast<const GammaDataType*>(p_gamma),
make_tuple(N),
static_cast<const GammaDataType*>(kargs.p_gamma),
make_tuple(kargs.N),
make_tuple(1),
number<kNPerThread>{},
number<1>{});
......@@ -351,8 +343,8 @@ struct Layernorm2dFwd
const auto beta_n = [&]() {
const auto gamma_dram_naive = make_naive_tensor_view<address_space_enum::global>(
static_cast<const BetaDataType*>(p_beta),
make_tuple(N),
static_cast<const BetaDataType*>(kargs.p_beta),
make_tuple(kargs.N),
make_tuple(1),
number<kNPerThread>{},
number<1>{});
......@@ -369,10 +361,10 @@ struct Layernorm2dFwd
x_m_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {iM, 0}, xDstr);
const auto y_m_n = [&]() {
const auto y_dram_naive =
make_naive_tensor_view<address_space_enum::global>(static_cast<YDataType*>(p_y),
make_tuple(M, N),
make_tuple(N, 1),
const auto y_dram_naive = make_naive_tensor_view<address_space_enum::global>(
static_cast<YDataType*>(kargs.p_y),
make_tuple(kargs.M, kargs.N),
make_tuple(kargs.N, 1),
number<kNPerThread>{},
number<1>{});
......@@ -399,7 +391,9 @@ struct Layernorm2dFwd
const auto mean_m = [&]() {
const auto mean_dram_naive =
make_naive_tensor_view_packed<address_space_enum::global>(
static_cast<MeanDataType*>(p_mean), make_tuple(M), number<1>{});
static_cast<MeanDataType*>(kargs.p_mean),
make_tuple(kargs.M),
number<1>{});
return pad_tensor_view(
mean_dram_naive, make_tuple(number<kMPerBlock>{}), sequence<kPadM>{});
......@@ -417,7 +411,9 @@ struct Layernorm2dFwd
const auto inv_std_m = [&]() {
const auto inv_std_dram_naive =
make_naive_tensor_view_packed<address_space_enum::global>(
static_cast<InvStdDataType*>(p_invStd), make_tuple(M), number<1>{});
static_cast<InvStdDataType*>(kargs.p_invStd),
make_tuple(kargs.M),
number<1>{});
return pad_tensor_view(
inv_std_dram_naive, make_tuple(number<kMPerBlock>{}), sequence<kPadM>{});
......@@ -437,8 +433,8 @@ struct Layernorm2dFwd
y_block_window,
mean_block_window,
inv_std_block_window,
static_cast<const ComputeDataType>(epsilon),
N);
static_cast<const ComputeDataType>(kargs.epsilon),
kargs.N);
}
else
{
......@@ -448,8 +444,8 @@ struct Layernorm2dFwd
y_block_window,
mean_block_window,
inv_std_block_window,
static_cast<const ComputeDataType>(epsilon),
N);
static_cast<const ComputeDataType>(kargs.epsilon),
kargs.N);
}
}
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment