"include/conv_common.hpp" did not exist on "88b77181aab1198b41b612f6d03b6dfb2d32bd40"
Commit eb6405ee authored by rocking's avatar rocking
Browse files

Sync the naming

parent e9a41755
@@ -32,13 +32,13 @@ template <typename XDataType,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
-index_t InSrcVectorDim,
-index_t InSrcVectorSize,
+index_t XSrcVectorDim,
+index_t XSrcVectorSize,
index_t GammaSrcVectorDim,
index_t GammaSrcVectorSize,
index_t BetaSrcVectorDim,
index_t BetaSrcVectorSize,
-index_t OutDstVectorSize>
+index_t YDstVectorSize>
struct DeviceLayernorm : public BaseOperator
{
static_assert(
@@ -71,36 +71,36 @@ struct DeviceLayernorm : public BaseOperator
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
-InSrcVectorDim,
-InSrcVectorSize,
-1>; // OutDstVectorSize
+XSrcVectorDim,
+XSrcVectorSize,
+1>; // YDstVectorSize
using GridDesc_M_K = decltype(Reduction::MakeSrc2dDescriptor({1}, {1}, 1, 1));
-using GridwiseReduce = GridwiseLayernorm_mk_to_mk<XDataType,
-GammaDataType,
-BetaDataType,
-YDataType,
-AccDataType,
-GridDesc_M_K,
-BlockSize,
-MThreadClusterSize,
-KThreadClusterSize,
-MThreadSliceSize,
-KThreadSliceSize,
-InSrcVectorDim,
-InSrcVectorSize,
-GammaSrcVectorDim,
-GammaSrcVectorSize,
-BetaSrcVectorDim,
-BetaSrcVectorSize,
-OutDstVectorSize,
-false>;
+using GridwiseReduceLayernormGeneric = GridwiseLayernorm_mk_to_mk<XDataType,
+GammaDataType,
+BetaDataType,
+YDataType,
+AccDataType,
+GridDesc_M_K,
+BlockSize,
+MThreadClusterSize,
+KThreadClusterSize,
+MThreadSliceSize,
+KThreadSliceSize,
+XSrcVectorDim,
+XSrcVectorSize,
+GammaSrcVectorDim,
+GammaSrcVectorSize,
+BetaSrcVectorDim,
+BetaSrcVectorSize,
+YDstVectorSize,
+false>;
struct Argument : public Reduction::Argument
{
-Argument(const std::vector<index_t> inLengths,
-const std::vector<index_t> inStrides,
+Argument(const std::vector<index_t> lengths,
+const std::vector<index_t> xStrides,
const std::vector<index_t> gammaStrides,
const std::vector<index_t> betaStrides,
const std::vector<index_t> reduceDims,
@@ -109,8 +109,8 @@ struct DeviceLayernorm : public BaseOperator
const GammaDataType* p_gamma,
const BetaDataType* p_beta,
YDataType* p_y)
-: Reduction::Argument(inLengths,
-inStrides,
+: Reduction::Argument(lengths,
+xStrides,
{},
{},
reduceDims,
@@ -142,16 +142,16 @@ struct DeviceLayernorm : public BaseOperator
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
-const auto in_grid_desc_m_k = Reduction::MakeSrc2dDescriptor(
+const auto x_grid_desc_m_k = Reduction::MakeSrc2dDescriptor(
arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration);
const auto gamma_grid_desc_m_k = Reduction::MakeSrc2dDescriptor(
arg.inLengths_, arg.gammaStrides_, arg.blkGroupSize, arg.numBlockTileIteration);
const auto beta_grid_desc_m_k = Reduction::MakeSrc2dDescriptor(
arg.inLengths_, arg.betaStrides_, arg.blkGroupSize, arg.numBlockTileIteration);
-const auto out_grid_desc_m_k = Reduction::MakeSrc2dDescriptor(
+const auto y_grid_desc_m_k = Reduction::MakeSrc2dDescriptor(
arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration);
-const auto kernel_main = kernel_layernorm<GridwiseReduce,
+const auto kernel_main = kernel_layernorm<GridwiseReduceLayernormGeneric,
XDataType,
GammaDataType,
BetaDataType,
@@ -166,10 +166,10 @@ struct DeviceLayernorm : public BaseOperator
dim3(arg.gridSize),
dim3(BlockSize),
0,
-in_grid_desc_m_k,
+x_grid_desc_m_k,
gamma_grid_desc_m_k,
beta_grid_desc_m_k,
-out_grid_desc_m_k,
+y_grid_desc_m_k,
arg.blkGroupSize,
arg.numBlockTileIteration,
arg.epsilon_,
@@ -197,7 +197,7 @@ struct DeviceLayernorm : public BaseOperator
return false;
}
-if(p_arg_->inLengths_[Rank - 1] % OutDstVectorSize != 0)
+if(p_arg_->inLengths_[Rank - 1] % YDstVectorSize != 0)
{
return false;
}
@@ -241,19 +241,19 @@ struct DeviceLayernorm : public BaseOperator
return true;
};
-std::unique_ptr<BaseArgument> MakeArgumentPointer(const std::vector<index_t> inLengths,
-const std::vector<index_t> inStrides,
+std::unique_ptr<BaseArgument> MakeArgumentPointer(const std::vector<index_t> lengths,
+const std::vector<index_t> xStrides,
const std::vector<index_t> gammaStrides,
const std::vector<index_t> betaStrides,
-const std::vector<int> reduceDims,
+const std::vector<index_t> reduceDims,
AccDataType epsilon,
const void* p_x,
const void* p_gamma,
const void* p_beta,
void* p_y)
{
-return std::make_unique<Argument>(inLengths,
-inStrides,
+return std::make_unique<Argument>(lengths,
+xStrides,
gammaStrides,
betaStrides,
reduceDims,
@@ -274,7 +274,7 @@ struct DeviceLayernorm : public BaseOperator
str << "DeviceLayernorm<" << BlockSize << ",";
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">";
str << "XSrcVectorDim_" << XSrcVectorDim << "_XSrcVectorSize_" << XSrcVectorSize << "_YDstVectorSize_" << YDstVectorSize << ">";
// clang-format on
return str.str();
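On the device-op side the entry point keeps its shape; only the parameter names change. Below is a minimal caller-side sketch of `MakeArgumentPointer` under the new naming; the `DeviceLayernormInstance` stand-in, the concrete lengths and strides, the `float` epsilon, and the `p_*` device pointers are illustrative assumptions, not part of this commit:

```cpp
#include <cstdint>
#include <memory>
#include <vector>

// Caller-side sketch only. "DeviceLayernormInstance" stands in for a concrete
// DeviceLayernorm<...> instantiation, index_t for ck::index_t (assumed int32_t),
// and the p_* pointers for already-allocated device buffers.
using index_t = std::int32_t;

template <typename DeviceLayernormInstance>
auto make_layernorm_argument(DeviceLayernormInstance& device,
                             const void* p_x,
                             const void* p_gamma,
                             const void* p_beta,
                             void* p_y)
{
    const std::vector<index_t> lengths      = {1024, 1024}; // M x K problem
    const std::vector<index_t> xStrides     = {1024, 1};    // row-major x (and y)
    const std::vector<index_t> gammaStrides = {0, 1};       // broadcast over M
    const std::vector<index_t> betaStrides  = {0, 1};       // broadcast over M
    const std::vector<index_t> reduceDims   = {1};          // normalize along K

    // Argument order matches MakeArgumentPointer as renamed in this commit.
    return device.MakeArgumentPointer(lengths,
                                      xStrides,
                                      gammaStrides,
                                      betaStrides,
                                      reduceDims,
                                      1e-4f, // epsilon, AccDataType assumed float
                                      p_x,
                                      p_gamma,
                                      p_beta,
                                      p_y);
}
```

The zero gamma/beta strides assume the affine parameters are broadcast along M, which is why they carry stride vectors separate from `xStrides`.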
@@ -21,10 +21,10 @@ template <typename GridwiseReduction,
typename YDataType,
typename AccDataType,
typename GridDesc_M_K>
-__global__ void kernel_layernorm(const GridDesc_M_K in_grid_desc_m_k,
+__global__ void kernel_layernorm(const GridDesc_M_K x_grid_desc_m_k,
const GridDesc_M_K gamma_grid_desc_m_k,
const GridDesc_M_K beta_grid_desc_m_k,
-const GridDesc_M_K out_grid_desc_m_k,
+const GridDesc_M_K y_grid_desc_m_k,
index_t block_group_size,
index_t num_k_block_tile_iteration,
AccDataType epsilon,
@@ -33,10 +33,10 @@ __global__ void kernel_layernorm(const GridDesc_M_K x_grid_desc_m_k,
const BetaDataType* const __restrict__ p_beta_global,
YDataType* const __restrict__ p_y_global)
{
-GridwiseReduction::Run(in_grid_desc_m_k,
+GridwiseReduction::Run(x_grid_desc_m_k,
gamma_grid_desc_m_k,
beta_grid_desc_m_k,
-out_grid_desc_m_k,
+y_grid_desc_m_k,
block_group_size,
num_k_block_tile_iteration,
epsilon,
@@ -57,22 +57,22 @@ template <typename XDataType,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
-index_t InSrcVectorDim,
-index_t InSrcVectorSize,
+index_t XSrcVectorDim,
+index_t XSrcVectorSize,
index_t GammaSrcVectorDim,
index_t GammaSrcVectorSize,
index_t BetaSrcVectorDim,
index_t BetaSrcVectorSize,
-index_t OutDstVectorSize,
+index_t YDstVectorSize,
bool SweepOnce>
struct GridwiseLayernorm_mk_to_mk
{
-static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
-(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
-(KThreadSliceSize % OutDstVectorSize == 0),
+static_assert(((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
+(XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0)) &&
+(KThreadSliceSize % YDstVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
-static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);
+static constexpr bool reorder_thread_cluster = (XSrcVectorDim == 0);
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
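The static_assert above is the contract that makes the renamed vector parameters meaningful: x reads vectorize along the dimension selected by `XSrcVectorDim`, while y writes always vectorize along K. A standalone sketch of the same condition with made-up slice and vector sizes (the numeric values are assumptions for illustration):

```cpp
// Illustration of the constraint above; values are made up. With
// XSrcVectorDim == 1 the x loads vectorize along K, so KThreadSliceSize must
// be divisible by XSrcVectorSize, and the y stores always vectorize along K,
// so KThreadSliceSize must also be divisible by YDstVectorSize.
constexpr int MThreadSliceSize = 4;
constexpr int KThreadSliceSize = 8;
constexpr int XSrcVectorDim    = 1;
constexpr int XSrcVectorSize   = 4;
constexpr int YDstVectorSize   = 2;

static_assert(((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
               (XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0)) &&
                  (KThreadSliceSize % YDstVectorSize == 0),
              "invalid slice/vector size combination");
```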
@@ -115,10 +115,10 @@ struct GridwiseLayernorm_mk_to_mk
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
-__device__ static void Run(const GridDesc_M_K& in_grid_desc_m_k,
+__device__ static void Run(const GridDesc_M_K& x_grid_desc_m_k,
const GridDesc_M_K& gamma_grid_desc_m_k,
const GridDesc_M_K& beta_grid_desc_m_k,
-const GridDesc_M_K& out_grid_desc_m_k,
+const GridDesc_M_K& y_grid_desc_m_k,
index_t block_group_size,
index_t num_k_block_tile_iteration,
AccDataType epsilon,
@@ -135,14 +135,14 @@ struct GridwiseLayernorm_mk_to_mk
// LDS
__shared__ AccDataType p_reduce_work_buffer[BlockSize];
-auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-p_y_global, out_grid_desc_m_k.GetElementSpaceSize());
+auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+p_y_global, y_grid_desc_m_k.GetElementSpaceSize());
auto reduce_work_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
-in_thread_buf;
+x_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
gamma_thread_buf;
@@ -151,12 +151,12 @@ struct GridwiseLayernorm_mk_to_mk
beta_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
-out_thread_buf;
+y_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr,
AccDataType,
MThreadSliceSize * KThreadSliceSize,
-true>& in_square_thread_buf = out_thread_buf;
+true>& x_square_thread_buf = y_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> mean_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
@@ -192,11 +192,11 @@ struct GridwiseLayernorm_mk_to_mk
decltype(thread_buffer_desc),
ThreadBufferLengths,
ThreadBufferDimAccessOrder,
-InSrcVectorDim,
-InSrcVectorSize,
+XSrcVectorDim,
+XSrcVectorSize,
1,
true>(
-in_grid_desc_m_k,
+x_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
block_local_id * reduceSizePerBlock +
thread_k_cluster_id * KThreadSliceSize));
@@ -238,24 +238,26 @@ struct GridwiseLayernorm_mk_to_mk
PassThroughOp,
ThreadBufferLengths,
ThreadBufferDimAccessOrder,
-InSrcVectorDim,
-OutDstVectorSize,
+XSrcVectorDim,
+YDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
true>(
-out_grid_desc_m_k,
+y_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
block_local_id * reduceSizePerBlock +
thread_k_cluster_id * KThreadSliceSize),
PassThroughOp{});
-constexpr auto in_thread_copy_fwd_step =
+// Copy x from cache
+// first pass: fwd, second pass: bwd
+constexpr auto thread_copy_fwd_step =
make_multi_index(0, SweepOnce ? 0 : K_BlockTileSize);
-constexpr auto in_thread_copy_bwd_step =
+constexpr auto thread_copy_bwd_step =
make_multi_index(0, SweepOnce ? 0 : -K_BlockTileSize);
-const auto in_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-p_x_global, in_grid_desc_m_k.GetElementSpaceSize());
+const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+p_x_global, x_grid_desc_m_k.GetElementSpaceSize());
const auto gamma_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_gamma_global, gamma_grid_desc_m_k.GetElementSpaceSize());
@@ -264,28 +266,28 @@ struct GridwiseLayernorm_mk_to_mk
p_beta_global, beta_grid_desc_m_k.GetElementSpaceSize());
// E(x), E[x^2], var(x)
-int reduce_length = in_grid_desc_m_k.GetLength(I1);
+int reduce_length = x_grid_desc_m_k.GetLength(I1);
index_t reducedTiles = 0;
do
{
-threadwise_x_load.Run(in_grid_desc_m_k,
-in_global_val_buf,
+threadwise_x_load.Run(x_grid_desc_m_k,
+x_global_val_buf,
thread_buffer_desc,
make_tuple(I0, I0),
-in_thread_buf);
+x_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
-in_square_thread_buf(Number<offset>{}) =
-in_thread_buf(Number<offset>{}) * in_thread_buf(Number<offset>{});
+x_square_thread_buf(Number<offset>{}) =
+x_thread_buf(Number<offset>{}) * x_thread_buf(Number<offset>{});
});
});
-ThreadwiseSumReduce::Reduce(in_thread_buf, mean_thread_buf);
-ThreadwiseSumReduce::Reduce(in_square_thread_buf, mean_square_thread_buf);
+ThreadwiseSumReduce::Reduce(x_thread_buf, mean_thread_buf);
+ThreadwiseSumReduce::Reduce(x_square_thread_buf, mean_square_thread_buf);
-threadwise_x_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_fwd_step);
+threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step);
++reducedTiles;
} while(reducedTiles < num_k_block_tile_iteration);
@@ -303,23 +305,23 @@ struct GridwiseLayernorm_mk_to_mk
});
// y = (x - E[x]) / sqrt(var[x] + epsilon)
-auto thread_copy_tail = (num_k_block_tile_iteration - 1) * in_thread_copy_fwd_step;
+auto thread_copy_tail = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step;
-threadwise_x_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_bwd_step);
-threadwise_gamma_load.MoveSrcSliceWindow(in_grid_desc_m_k, thread_copy_tail);
-threadwise_beta_load.MoveSrcSliceWindow(in_grid_desc_m_k, thread_copy_tail);
-threadwise_y_store.MoveDstSliceWindow(out_grid_desc_m_k, thread_copy_tail);
+threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step);
+threadwise_gamma_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_tail);
+threadwise_beta_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_tail);
+threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_tail);
reducedTiles = 0;
do
{
if constexpr(!SweepOnce)
{
-threadwise_x_load.Run(in_grid_desc_m_k,
-in_global_val_buf,
+threadwise_x_load.Run(x_grid_desc_m_k,
+x_global_val_buf,
thread_buffer_desc,
make_tuple(I0, I0),
-in_thread_buf);
+x_thread_buf);
}
threadwise_gamma_load.Run(gamma_grid_desc_m_k,
@@ -338,27 +340,27 @@ struct GridwiseLayernorm_mk_to_mk
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
// normalize
-out_thread_buf(Number<offset>{}) =
-(in_thread_buf(Number<offset>{}) - mean_thread_buf(iM)) /
+y_thread_buf(Number<offset>{}) =
+(x_thread_buf(Number<offset>{}) - mean_thread_buf(iM)) /
sqrt(var_value_buf(iM) + epsilon);
// affine
-out_thread_buf(Number<offset>{}) =
-out_thread_buf(Number<offset>{}) * gamma_thread_buf(Number<offset>{}) +
+y_thread_buf(Number<offset>{}) =
+y_thread_buf(Number<offset>{}) * gamma_thread_buf(Number<offset>{}) +
beta_thread_buf(Number<offset>{});
});
});
threadwise_y_store.Run(thread_buffer_desc,
make_tuple(I0, I0),
-out_thread_buf,
-out_grid_desc_m_k,
-out_global_val_buf);
-threadwise_x_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_bwd_step);
-threadwise_gamma_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_bwd_step);
-threadwise_beta_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_bwd_step);
-threadwise_y_store.MoveDstSliceWindow(out_grid_desc_m_k, in_thread_copy_bwd_step);
+y_thread_buf,
+y_grid_desc_m_k,
+y_global_val_buf);
+threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step);
+threadwise_gamma_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step);
+threadwise_beta_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step);
+threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_bwd_step);
++reducedTiles;
} while(reducedTiles < num_k_block_tile_iteration);
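For reference, the renamed buffers implement a two-sweep layernorm per row of the M x K view: the first sweep accumulates E[x] and E[x^2] (so var(x) = E[x^2] - E[x]^2), and the second sweep re-reads x and applies the normalize and affine steps. Below is a plain host-side sketch of that computation, intended as a checking reference under assumed float types rather than as the device implementation:

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Host-side sketch of what GridwiseLayernorm_mk_to_mk computes per row.
// Sweep 1 forms E[x] and E[x^2]; var(x) = E[x^2] - E[x]^2.
// Sweep 2 applies y = (x - E[x]) / sqrt(var(x) + epsilon) * gamma + beta.
void layernorm_mk_reference(const std::vector<float>& x,     // M * K, row-major
                            const std::vector<float>& gamma, // K
                            const std::vector<float>& beta,  // K
                            std::vector<float>& y,           // M * K
                            std::size_t M,
                            std::size_t K,
                            float epsilon)
{
    for(std::size_t m = 0; m < M; ++m)
    {
        // Sweep 1: accumulate both moments in one pass over the row.
        float sum = 0.f, sum_sq = 0.f;
        for(std::size_t k = 0; k < K; ++k)
        {
            const float v = x[m * K + k];
            sum += v;
            sum_sq += v * v;
        }
        const float mean    = sum / static_cast<float>(K);
        const float var     = sum_sq / static_cast<float>(K) - mean * mean;
        const float inv_std = 1.f / std::sqrt(var + epsilon);

        // Sweep 2: normalize, then apply the gamma/beta affine transform.
        for(std::size_t k = 0; k < K; ++k)
            y[m * K + k] = (x[m * K + k] - mean) * inv_std * gamma[k] + beta[k];
    }
}
```

Reusing the sum-of-squares storage for the output (`x_square_thread_buf` aliases `y_thread_buf` above) is safe because E[x^2] is no longer needed once var(x) has been formed.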