"...resnet50_tensorflow.git" did not exist on "002959e89662d9b17d5b1c3275e7b0f420c902fd"
Commit d08aa99e authored by Anthony Chang

clang-format

parent f906b23d
@@ -67,8 +67,7 @@ template <typename ALayout,
           index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
           index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
           LoopScheduler LoopSched = make_default_loop_scheduler()>
-struct DeviceGemmLayerNorm_Xdl_CShuffle
-    : public BaseOperator
+struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
 {
     using DeviceOp = DeviceGemmLayerNorm_Xdl_CShuffle;

@@ -463,7 +462,8 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle
                 GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                     c_grid_desc_m_n_);
-            c0_grid_desc_nblock_nperblock_ = GridwiseGemm::MakeC0GridDescriptor_NBlock_NPerBlock(c0_grid_desc_n_);
+            c0_grid_desc_nblock_nperblock_ =
+                GridwiseGemm::MakeC0GridDescriptor_NBlock_NPerBlock(c0_grid_desc_n_);

             // TODO ANT: adopt tensile style workgroup mapping
             block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_);

@@ -483,8 +483,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle
         C0GridDesc_N c0_grid_desc_n_;
         typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
             c_grid_desc_mblock_mperblock_nblock_nperblock_;
-        typename GridwiseGemm::C0GridDescriptor_NBlock_NPerBlock
-            c0_grid_desc_nblock_nperblock_;
+        typename GridwiseGemm::C0GridDescriptor_NBlock_NPerBlock c0_grid_desc_nblock_nperblock_;
         typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
         AElementwiseOperation a_element_op_;
         BElementwiseOperation b_element_op_;
...
@@ -31,7 +31,8 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-        kernel_gemm_layernorm_xdl_cshuffle_v1(const FloatAB* __restrict__ p_a_grid,
+        kernel_gemm_layernorm_xdl_cshuffle_v1(
+            const FloatAB* __restrict__ p_a_grid,
             const FloatAB* __restrict__ p_b_grid,
             FloatC* __restrict__ p_c_grid,             // MxN
             const FloatC* __restrict__ p_c0_bias_grid, // 1xN
@@ -44,8 +45,7 @@ __global__ void
             const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
             const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                 c_grid_desc_mblock_mperblock_nblock_nperblock,
-            const C0GridDescriptor_NBlock_NPerBlock
-                c0_grid_desc_nblock_nperblock,
+            const C0GridDescriptor_NBlock_NPerBlock c0_grid_desc_nblock_nperblock,
             const Block2CTileMap block_2_ctile_map)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))

@@ -247,7 +247,8 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
         // static check: all waves in the workgroups combined must cover whole N extent in order
         // to have efficient N-dim reduction
-        static_assert(CShuffleNXdlPerWavePerShuffle == NXdlPerWave, "condition not met for efficient layernorm");
+        static_assert(CShuffleNXdlPerWavePerShuffle == NXdlPerWave,
+                      "condition not met for efficient layernorm");

         // check gridwise gemm pipeline
         const auto num_k_loop = K / KPerBlock;

@@ -357,14 +358,15 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
     using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
         MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>;
-    using C0GridDescriptor_NBlock_NPerBlock = remove_cvref_t<decltype(
-        MakeC0GridDescriptor_NBlock_NPerBlock(C0GridDesc_N{}))>;
+    using C0GridDescriptor_NBlock_NPerBlock =
+        remove_cvref_t<decltype(MakeC0GridDescriptor_NBlock_NPerBlock(C0GridDesc_N{}))>;
     using DefaultBlock2CTileMap =
         remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;

     template <bool HasMainKBlockLoop, typename Block2CTileMap>
-    __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
+    __device__ static void
+    Run(const FloatAB* __restrict__ p_a_grid,
         const FloatAB* __restrict__ p_b_grid,
         FloatC* __restrict__ p_c_grid,
         const FloatC* __restrict__ p_c0_bias_grid, // 1xN

@@ -378,8 +380,7 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
         const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
         const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
             c_grid_desc_mblock_mperblock_nblock_nperblock,
-        const C0GridDescriptor_NBlock_NPerBlock&
-            c0_grid_desc_nblock_nperblock,
+        const C0GridDescriptor_NBlock_NPerBlock& c0_grid_desc_nblock_nperblock,
         const Block2CTileMap& block_2_ctile_map)
     {
         const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(

@@ -802,7 +803,9 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
                 CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
                 InMemoryDataOperationEnum::Set,
                 1,
-                true>{c_reduce_block_desc_mperblock_nperblock, c_reduce_thread_data_idx_begin, tensor_operation::element_wise::PassThrough{}};
+                true>{c_reduce_block_desc_mperblock_nperblock,
+                      c_reduce_thread_data_idx_begin,
+                      tensor_operation::element_wise::PassThrough{}};

             auto c0_thread_copy_global_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
                 FloatC,

@@ -814,9 +817,9 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
                 3,
                 CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
                 1,
-                true>(c0_grid_desc_mblock_mperblock_nblock_nperblock,
-                      make_multi_index(
-                          I0,
+                true>(
+                    c0_grid_desc_mblock_mperblock_nblock_nperblock,
+                    make_multi_index(I0,
                          m_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I0],
                          I0,
                          n_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I1]));

@@ -869,7 +872,6 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
                 make_tuple(I0, I0),
                 c_reduce_thread_buf);
-
             c0_thread_copy_global_to_vgpr.Run(
                 c0_grid_desc_mblock_mperblock_nblock_nperblock,
                 c0_bias_grid_buf,

@@ -877,9 +879,8 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
                 make_tuple(I0, I0, I0, I0),
                 c0_thread_buf);
-            static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}([&](auto i) {
-                c_reduce_thread_buf(i) += c0_thread_buf(i);
-            });
+            static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}(
+                [&](auto i) { c_reduce_thread_buf(i) += c0_thread_buf(i); });

             using ThreadwiseReduceD0 =
                 ThreadwiseReduction<FloatReduceAcc,

@@ -908,7 +909,8 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
             ThreadwiseReduceD1::Reduce(c_reduce_thread_buf, d1_thread_buf);

             // reduce within workgroup
-            using BlockwiseReduce = PartitionedBlockwiseReduction<FloatReduceAcc,
+            using BlockwiseReduce = PartitionedBlockwiseReduction<
+                FloatReduceAcc,
                 BlockSize,
                 CReduceThreadClusterLengths_MPerBlock_NPerBlock, // ThreadClusterLengths_M_K
                 Sequence<1, 0>,                                  // ThreadClusterArrangeOrder

@@ -917,13 +919,17 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
             static_for<0, mreduce_per_thread, 1>{}([&](auto i) {
                 block_sync_lds();
-                BlockwiseReduce::Reduce(d_reduce_work_buf, d0_thread_buf(i)); // blockwise reduced sum
+                BlockwiseReduce::Reduce(d_reduce_work_buf,
+                                        d0_thread_buf(i)); // blockwise reduced sum
                 block_sync_lds();
-                BlockwiseReduce::Reduce(d_reduce_work_buf, d1_thread_buf(i)); // blockwise reduced squared sum
+                BlockwiseReduce::Reduce(d_reduce_work_buf,
+                                        d1_thread_buf(i)); // blockwise reduced squared sum
             });

             // normalize
-            const index_t NRaw = c_grid_desc_mblock_mperblock_nblock_nperblock.GetTransforms()[I0].GetUpperLengths()[I1]; // TODO: proper handle
+            const index_t NRaw =
+                c_grid_desc_mblock_mperblock_nblock_nperblock.GetTransforms()[I0]
+                    .GetUpperLengths()[I1]; // TODO: proper handle
             static_for<0, mreduce_per_thread, 1>{}([&](auto im) {
                 static_for<0, nreduce_per_thread, 1>{}([&](auto in) {

@@ -941,7 +947,9 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
                     FloatReduceAcc denom   = c_reduce_thread_buf(dst_offset) - avg_sum;
                     FloatReduceAcc divisor = epsilon + avg_squared_sum - avg_sum * avg_sum;
                     FloatReduceAcc divisor_sqrt;
-                    tensor_operation::element_wise::UnarySqrt<FloatReduceAcc, FloatReduceAcc>{}(divisor_sqrt, divisor);
+                    tensor_operation::element_wise::UnarySqrt<FloatReduceAcc,
+                                                              FloatReduceAcc>{}(
+                        divisor_sqrt, divisor);
                     c_reduce_thread_buf(dst_offset) = denom / divisor_sqrt;
                 });

@@ -955,7 +963,8 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
                 make_tuple(I0, I0, I0, I0),
                 c0_thread_buf);
-            static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}([&](auto i) {
+            static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}(
+                [&](auto i) {
                     c_reduce_thread_buf(i) *= c0_thread_buf(i); // * gamma
                 });

@@ -966,7 +975,8 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
                 make_tuple(I0, I0, I0, I0),
                 c0_thread_buf);
-            static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}([&](auto i) {
+            static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}(
+                [&](auto i) {
                     c_reduce_thread_buf(i) += c0_thread_buf(i); // + beta
                 });
...
@@ -12,10 +12,7 @@ template <typename T>
 struct PrintAsType<T, typename std::enable_if<std::is_floating_point<T>::value>::type>
 {
     using type = float;
-    __host__ __device__ static void Print(const T& p)
-    {
-        printf("%.3f ", static_cast<type>(p));
-    }
+    __host__ __device__ static void Print(const T& p) { printf("%.3f ", static_cast<type>(p)); }
 };

 template <>

@@ -32,10 +29,7 @@ template <typename T>
 struct PrintAsType<T, typename std::enable_if<std::is_integral<T>::value>::type>
 {
     using type = int;
-    __host__ __device__ static void Print(const T& p)
-    {
-        printf("%d ", static_cast<type>(p));
-    }
+    __host__ __device__ static void Print(const T& p) { printf("%d ", static_cast<type>(p)); }
 };
 } // namespace detail
...
@@ -18,8 +18,13 @@ template <typename ADataType,
           typename CElementwiseOperation>
 struct ReferenceGemmLayernorm : public device::BaseOperator
 {
-    using ReferenceGemmInstance = ReferenceGemm<ADataType, BDataType, AccDataType, AccDataType,
-        AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>;
+    using ReferenceGemmInstance = ReferenceGemm<ADataType,
+                                                BDataType,
+                                                AccDataType,
+                                                AccDataType,
+                                                AElementwiseOperation,
+                                                BElementwiseOperation,
+                                                CElementwiseOperation>;

     // D = Layernorm(acc + broadcast(bias)) * broadcast(gamma) + broadcast(beta)
     template <typename InDataType, typename OutDataType, typename ComputeDataType>

@@ -129,8 +134,12 @@ struct ReferenceGemmLayernorm : public device::BaseOperator
         auto ref_gemm     = ReferenceGemmInstance{};
         auto ref_invoker  = ref_gemm.MakeInvoker();
-        auto ref_argument = ref_gemm.MakeArgument(
-            arg.a_m_k_, arg.b_k_n_, acc_m_n, arg.a_element_op_, arg.b_element_op_, arg.c_element_op_);
+        auto ref_argument = ref_gemm.MakeArgument(arg.a_m_k_,
+                                                  arg.b_k_n_,
+                                                  acc_m_n,
+                                                  arg.a_element_op_,
+                                                  arg.b_element_op_,
+                                                  arg.c_element_op_);

         ref_invoker.Run(ref_argument);
...
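For context on what the reformatted epilogue computes: the reference comment above reads D = Layernorm(acc + broadcast(bias)) * broadcast(gamma) + broadcast(beta), and the kernel derives the per-row variance from its blockwise sum and squared-sum reductions as epsilon + E[x^2] - E[x]^2 before dividing by the square root. The following is a minimal serial host-side sketch of that formula, written only for illustration; the function and variable names (gemm_layernorm_rows, acc, bias, gamma, beta) are hypothetical and do not appear in the commit, which performs the same math per M-row tile with blockwise reductions on the GPU.

// Illustrative sketch of the row-wise GEMM + layernorm epilogue, not CK code.
#include <cmath>
#include <vector>

void gemm_layernorm_rows(std::vector<float>& acc,         // MxN GEMM accumulator, row-major
                         const std::vector<float>& bias,  // 1xN, broadcast over rows
                         const std::vector<float>& gamma, // 1xN
                         const std::vector<float>& beta,  // 1xN
                         int M, int N, float epsilon)
{
    for(int m = 0; m < M; ++m)
    {
        float sum = 0.f, sq_sum = 0.f;
        for(int n = 0; n < N; ++n)
        {
            const float x = acc[m * N + n] + bias[n]; // acc + broadcast(bias)
            acc[m * N + n] = x;
            sum += x;
            sq_sum += x * x;
        }
        const float mean    = sum / N;
        const float var     = sq_sum / N - mean * mean; // E[x^2] - E[x]^2, as in the kernel
        const float inv_std = 1.f / std::sqrt(epsilon + var);
        for(int n = 0; n < N; ++n)
        {
            // D = Layernorm(acc + bias) * broadcast(gamma) + broadcast(beta)
            acc[m * N + n] = (acc[m * N + n] - mean) * inv_std * gamma[n] + beta[n];
        }
    }
}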