Commit d08aa99e authored by Anthony Chang's avatar Anthony Chang
Browse files

clang-format

parent f906b23d
...@@ -67,8 +67,7 @@ template <typename ALayout, ...@@ -67,8 +67,7 @@ template <typename ALayout,
index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler()> LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceGemmLayerNorm_Xdl_CShuffle struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
: public BaseOperator
{ {
using DeviceOp = DeviceGemmLayerNorm_Xdl_CShuffle; using DeviceOp = DeviceGemmLayerNorm_Xdl_CShuffle;
...@@ -463,7 +462,8 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle ...@@ -463,7 +462,8 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n_); c_grid_desc_m_n_);
c0_grid_desc_nblock_nperblock_ = GridwiseGemm::MakeC0GridDescriptor_NBlock_NPerBlock(c0_grid_desc_n_); c0_grid_desc_nblock_nperblock_ =
GridwiseGemm::MakeC0GridDescriptor_NBlock_NPerBlock(c0_grid_desc_n_);
// TODO ANT: adopt tensile style workgroup mapping // TODO ANT: adopt tensile style workgroup mapping
block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_); block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_);
...@@ -483,8 +483,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle ...@@ -483,8 +483,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle
C0GridDesc_N c0_grid_desc_n_; C0GridDesc_N c0_grid_desc_n_;
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_; c_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::C0GridDescriptor_NBlock_NPerBlock typename GridwiseGemm::C0GridDescriptor_NBlock_NPerBlock c0_grid_desc_nblock_nperblock_;
c0_grid_desc_nblock_nperblock_;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
AElementwiseOperation a_element_op_; AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_; BElementwiseOperation b_element_op_;
......
...@@ -51,7 +51,7 @@ struct ThreadwiseReduction ...@@ -51,7 +51,7 @@ struct ThreadwiseReduction
static_assert(src_length_m == dst_length_m, "lengths of source and dst buffer must match!"); static_assert(src_length_m == dst_length_m, "lengths of source and dst buffer must match!");
using Op = OpReduce; using Op = OpReduce;
using Accumulation = detail::AccumulateWithNanCheck<PropagateNan, OpReduce, AccDataType>; using Accumulation = detail::AccumulateWithNanCheck<PropagateNan, OpReduce, AccDataType>;
template <typename SrcBufferType, typename DstBufferType> template <typename SrcBufferType, typename DstBufferType>
......
...@@ -12,10 +12,7 @@ template <typename T> ...@@ -12,10 +12,7 @@ template <typename T>
struct PrintAsType<T, typename std::enable_if<std::is_floating_point<T>::value>::type> struct PrintAsType<T, typename std::enable_if<std::is_floating_point<T>::value>::type>
{ {
using type = float; using type = float;
__host__ __device__ static void Print(const T& p) __host__ __device__ static void Print(const T& p) { printf("%.3f ", static_cast<type>(p)); }
{
printf("%.3f ", static_cast<type>(p));
}
}; };
template <> template <>
...@@ -32,10 +29,7 @@ template <typename T> ...@@ -32,10 +29,7 @@ template <typename T>
struct PrintAsType<T, typename std::enable_if<std::is_integral<T>::value>::type> struct PrintAsType<T, typename std::enable_if<std::is_integral<T>::value>::type>
{ {
using type = int; using type = int;
__host__ __device__ static void Print(const T& p) __host__ __device__ static void Print(const T& p) { printf("%d ", static_cast<type>(p)); }
{
printf("%d ", static_cast<type>(p));
}
}; };
} // namespace detail } // namespace detail
......
...@@ -18,21 +18,26 @@ template <typename ADataType, ...@@ -18,21 +18,26 @@ template <typename ADataType,
typename CElementwiseOperation> typename CElementwiseOperation>
struct ReferenceGemmLayernorm : public device::BaseOperator struct ReferenceGemmLayernorm : public device::BaseOperator
{ {
using ReferenceGemmInstance = ReferenceGemm<ADataType, BDataType, AccDataType, AccDataType, using ReferenceGemmInstance = ReferenceGemm<ADataType,
AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>; BDataType,
AccDataType,
AccDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>;
// D = Layernorm(acc + broadcast(bias)) * broadcast(gamma) + broadcast(beta) // D = Layernorm(acc + broadcast(bias)) * broadcast(gamma) + broadcast(beta)
template <typename InDataType, typename OutDataType, typename ComputeDataType> template <typename InDataType, typename OutDataType, typename ComputeDataType>
static void RunLayernorm(Tensor<OutDataType>& result, static void RunLayernorm(Tensor<OutDataType>& result,
const Tensor<ComputeDataType>& acc, // MxN const Tensor<ComputeDataType>& acc, // MxN
const Tensor<InDataType>& bias, // 1xN const Tensor<InDataType>& bias, // 1xN
const Tensor<InDataType>& gamma, // 1xN const Tensor<InDataType>& gamma, // 1xN
const Tensor<InDataType>& beta, // 1xN const Tensor<InDataType>& beta, // 1xN
const InDataType epsilon = 1e-5) const InDataType epsilon = 1e-5)
{ {
assert(acc.mDesc.GetLengths()[1] == bias.mDesc.GetLengths()[0] && assert(acc.mDesc.GetLengths()[1] == bias.mDesc.GetLengths()[0] &&
acc.mDesc.GetLengths()[1] == gamma.mDesc.GetLengths()[0] && acc.mDesc.GetLengths()[1] == gamma.mDesc.GetLengths()[0] &&
acc.mDesc.GetLengths()[1] == beta.mDesc.GetLengths()[0]); acc.mDesc.GetLengths()[1] == beta.mDesc.GetLengths()[0]);
size_t M = acc.mDesc.GetLengths()[0]; size_t M = acc.mDesc.GetLengths()[0];
size_t N = acc.mDesc.GetLengths()[1]; size_t N = acc.mDesc.GetLengths()[1];
...@@ -127,10 +132,14 @@ struct ReferenceGemmLayernorm : public device::BaseOperator ...@@ -127,10 +132,14 @@ struct ReferenceGemmLayernorm : public device::BaseOperator
Tensor<AccDataType> acc_m_n(arg.c_m_n_.mDesc); Tensor<AccDataType> acc_m_n(arg.c_m_n_.mDesc);
acc_m_n.GenerateTensorValue(GeneratorTensor_1<AccDataType>{0}); acc_m_n.GenerateTensorValue(GeneratorTensor_1<AccDataType>{0});
auto ref_gemm = ReferenceGemmInstance{}; auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker(); auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument( auto ref_argument = ref_gemm.MakeArgument(arg.a_m_k_,
arg.a_m_k_, arg.b_k_n_, acc_m_n, arg.a_element_op_, arg.b_element_op_, arg.c_element_op_); arg.b_k_n_,
acc_m_n,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_);
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment