Commit 03247367 authored by rocking's avatar rocking
Browse files

Refine arg of operator()

parent 5c736bc1
...@@ -34,18 +34,9 @@ float layernorm2d_fwd_(const S& s, A a) ...@@ -34,18 +34,9 @@ float layernorm2d_fwd_(const S& s, A a)
constexpr dim3 blocks = Kernel::BlockSize(); constexpr dim3 blocks = Kernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = 1; constexpr ck_tile::index_t kBlockPerCu = 1;
return ck_tile::launch_kernel(s, auto kargs = Kernel::MakeKargs(
ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, a.p_x, a.p_gamma, a.p_beta, a.p_y, a.p_mean, a.p_invStd, a.epsilon, a.M, a.N);
grids,
blocks, return ck_tile::launch_kernel(
0, s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
a.p_x,
a.p_gamma,
a.p_beta,
a.p_y,
a.p_mean,
a.p_invStd,
a.epsilon,
a.M,
a.N));
} }
...@@ -314,21 +314,13 @@ struct Layernorm2dFwd ...@@ -314,21 +314,13 @@ struct Layernorm2dFwd
} }
} }
CK_TILE_DEVICE void operator()(const void* p_x, CK_TILE_DEVICE void operator()(Kargs kargs) const
const void* p_gamma,
const void* p_beta,
void* p_y,
void* p_mean,
void* p_invStd,
const ComputeDataType epsilon,
ck_tile::index_t M,
ck_tile::index_t N) const
{ {
const auto x_m_n = [&]() { const auto x_m_n = [&]() {
const auto x_dram_naive = make_naive_tensor_view<address_space_enum::global>( const auto x_dram_naive = make_naive_tensor_view<address_space_enum::global>(
static_cast<const XDataType*>(p_x), static_cast<const XDataType*>(kargs.p_x),
make_tuple(M, N), make_tuple(kargs.M, kargs.N),
make_tuple(N, 1), make_tuple(kargs.N, 1),
number<kNPerThread>{}, number<kNPerThread>{},
number<1>{}); number<1>{});
...@@ -339,8 +331,8 @@ struct Layernorm2dFwd ...@@ -339,8 +331,8 @@ struct Layernorm2dFwd
const auto gamma_n = [&]() { const auto gamma_n = [&]() {
const auto gamma_dram_naive = make_naive_tensor_view<address_space_enum::global>( const auto gamma_dram_naive = make_naive_tensor_view<address_space_enum::global>(
static_cast<const GammaDataType*>(p_gamma), static_cast<const GammaDataType*>(kargs.p_gamma),
make_tuple(N), make_tuple(kargs.N),
make_tuple(1), make_tuple(1),
number<kNPerThread>{}, number<kNPerThread>{},
number<1>{}); number<1>{});
...@@ -351,8 +343,8 @@ struct Layernorm2dFwd ...@@ -351,8 +343,8 @@ struct Layernorm2dFwd
const auto beta_n = [&]() { const auto beta_n = [&]() {
const auto gamma_dram_naive = make_naive_tensor_view<address_space_enum::global>( const auto gamma_dram_naive = make_naive_tensor_view<address_space_enum::global>(
static_cast<const BetaDataType*>(p_beta), static_cast<const BetaDataType*>(kargs.p_beta),
make_tuple(N), make_tuple(kargs.N),
make_tuple(1), make_tuple(1),
number<kNPerThread>{}, number<kNPerThread>{},
number<1>{}); number<1>{});
...@@ -369,12 +361,12 @@ struct Layernorm2dFwd ...@@ -369,12 +361,12 @@ struct Layernorm2dFwd
x_m_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {iM, 0}, xDstr); x_m_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {iM, 0}, xDstr);
const auto y_m_n = [&]() { const auto y_m_n = [&]() {
const auto y_dram_naive = const auto y_dram_naive = make_naive_tensor_view<address_space_enum::global>(
make_naive_tensor_view<address_space_enum::global>(static_cast<YDataType*>(p_y), static_cast<YDataType*>(kargs.p_y),
make_tuple(M, N), make_tuple(kargs.M, kargs.N),
make_tuple(N, 1), make_tuple(kargs.N, 1),
number<kNPerThread>{}, number<kNPerThread>{},
number<1>{}); number<1>{});
return pad_tensor_view(y_dram_naive, return pad_tensor_view(y_dram_naive,
make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
...@@ -399,7 +391,9 @@ struct Layernorm2dFwd ...@@ -399,7 +391,9 @@ struct Layernorm2dFwd
const auto mean_m = [&]() { const auto mean_m = [&]() {
const auto mean_dram_naive = const auto mean_dram_naive =
make_naive_tensor_view_packed<address_space_enum::global>( make_naive_tensor_view_packed<address_space_enum::global>(
static_cast<MeanDataType*>(p_mean), make_tuple(M), number<1>{}); static_cast<MeanDataType*>(kargs.p_mean),
make_tuple(kargs.M),
number<1>{});
return pad_tensor_view( return pad_tensor_view(
mean_dram_naive, make_tuple(number<kMPerBlock>{}), sequence<kPadM>{}); mean_dram_naive, make_tuple(number<kMPerBlock>{}), sequence<kPadM>{});
...@@ -417,7 +411,9 @@ struct Layernorm2dFwd ...@@ -417,7 +411,9 @@ struct Layernorm2dFwd
const auto inv_std_m = [&]() { const auto inv_std_m = [&]() {
const auto inv_std_dram_naive = const auto inv_std_dram_naive =
make_naive_tensor_view_packed<address_space_enum::global>( make_naive_tensor_view_packed<address_space_enum::global>(
static_cast<InvStdDataType*>(p_invStd), make_tuple(M), number<1>{}); static_cast<InvStdDataType*>(kargs.p_invStd),
make_tuple(kargs.M),
number<1>{});
return pad_tensor_view( return pad_tensor_view(
inv_std_dram_naive, make_tuple(number<kMPerBlock>{}), sequence<kPadM>{}); inv_std_dram_naive, make_tuple(number<kMPerBlock>{}), sequence<kPadM>{});
...@@ -437,8 +433,8 @@ struct Layernorm2dFwd ...@@ -437,8 +433,8 @@ struct Layernorm2dFwd
y_block_window, y_block_window,
mean_block_window, mean_block_window,
inv_std_block_window, inv_std_block_window,
static_cast<const ComputeDataType>(epsilon), static_cast<const ComputeDataType>(kargs.epsilon),
N); kargs.N);
} }
else else
{ {
...@@ -448,8 +444,8 @@ struct Layernorm2dFwd ...@@ -448,8 +444,8 @@ struct Layernorm2dFwd
y_block_window, y_block_window,
mean_block_window, mean_block_window,
inv_std_block_window, inv_std_block_window,
static_cast<const ComputeDataType>(epsilon), static_cast<const ComputeDataType>(kargs.epsilon),
N); kargs.N);
} }
} }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment