Commit d08aa99e authored by Anthony Chang's avatar Anthony Chang
Browse files

clang-format

parent f906b23d
......@@ -67,8 +67,7 @@ template <typename ALayout,
index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceGemmLayerNorm_Xdl_CShuffle
: public BaseOperator
struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
{
using DeviceOp = DeviceGemmLayerNorm_Xdl_CShuffle;
......@@ -463,7 +462,8 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n_);
c0_grid_desc_nblock_nperblock_ = GridwiseGemm::MakeC0GridDescriptor_NBlock_NPerBlock(c0_grid_desc_n_);
c0_grid_desc_nblock_nperblock_ =
GridwiseGemm::MakeC0GridDescriptor_NBlock_NPerBlock(c0_grid_desc_n_);
// TODO ANT: adopt tensile style workgroup mapping
block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_);
......@@ -483,8 +483,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle
C0GridDesc_N c0_grid_desc_n_;
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::C0GridDescriptor_NBlock_NPerBlock
c0_grid_desc_nblock_nperblock_;
typename GridwiseGemm::C0GridDescriptor_NBlock_NPerBlock c0_grid_desc_nblock_nperblock_;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
......
......@@ -31,22 +31,22 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_layernorm_xdl_cshuffle_v1(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, // MxN
const FloatC* __restrict__ p_c0_bias_grid, // 1xN
const FloatC* __restrict__ p_c0_gamma_grid, // 1xN
const FloatC* __restrict__ p_c0_beta_grid, // 1xN
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const C0GridDescriptor_NBlock_NPerBlock
c0_grid_desc_nblock_nperblock,
const Block2CTileMap block_2_ctile_map)
kernel_gemm_layernorm_xdl_cshuffle_v1(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, // MxN
const FloatC* __restrict__ p_c0_bias_grid, // 1xN
const FloatC* __restrict__ p_c0_gamma_grid, // 1xN
const FloatC* __restrict__ p_c0_beta_grid, // 1xN
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const C0GridDescriptor_NBlock_NPerBlock c0_grid_desc_nblock_nperblock,
const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
......@@ -247,7 +247,8 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// static check: all waves in the workgroups combined must cover whole N extent in order
// to have efficient N-dim reduction
static_assert(CShuffleNXdlPerWavePerShuffle == NXdlPerWave, "condition not met for efficient layernorm");
static_assert(CShuffleNXdlPerWavePerShuffle == NXdlPerWave,
"condition not met for efficient layernorm");
// check gridwise gemm pipeline
const auto num_k_loop = K / KPerBlock;
......@@ -357,30 +358,30 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>;
using C0GridDescriptor_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeC0GridDescriptor_NBlock_NPerBlock(C0GridDesc_N{}))>;
using C0GridDescriptor_NBlock_NPerBlock =
remove_cvref_t<decltype(MakeC0GridDescriptor_NBlock_NPerBlock(C0GridDesc_N{}))>;
using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;
template <bool HasMainKBlockLoop, typename Block2CTileMap>
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const FloatC* __restrict__ p_c0_bias_grid, // 1xN
const FloatC* __restrict__ p_c0_gamma_grid, // 1xN
const FloatC* __restrict__ p_c0_beta_grid, // 1xN
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
const C0GridDescriptor_NBlock_NPerBlock&
c0_grid_desc_nblock_nperblock,
const Block2CTileMap& block_2_ctile_map)
__device__ static void
Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const FloatC* __restrict__ p_c0_bias_grid, // 1xN
const FloatC* __restrict__ p_c0_gamma_grid, // 1xN
const FloatC* __restrict__ p_c0_beta_grid, // 1xN
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
const C0GridDescriptor_NBlock_NPerBlock& c0_grid_desc_nblock_nperblock,
const Block2CTileMap& block_2_ctile_map)
{
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
......@@ -686,9 +687,9 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
const auto c0_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
c0_grid_desc_nblock_nperblock,
make_tuple(make_insert_transform(I1),
make_insert_transform(I1),
make_pass_through_transform(NBlock),
make_pass_through_transform(NPerBlock)),
make_insert_transform(I1),
make_pass_through_transform(NBlock),
make_pass_through_transform(NPerBlock)),
make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
......@@ -802,7 +803,9 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
InMemoryDataOperationEnum::Set,
1,
true>{c_reduce_block_desc_mperblock_nperblock, c_reduce_thread_data_idx_begin, tensor_operation::element_wise::PassThrough{}};
true>{c_reduce_block_desc_mperblock_nperblock,
c_reduce_thread_data_idx_begin,
tensor_operation::element_wise::PassThrough{}};
auto c0_thread_copy_global_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
FloatC,
......@@ -814,12 +817,12 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
3,
CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
1,
true>(c0_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(
I0,
m_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I0],
I0,
n_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I1]));
true>(
c0_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(I0,
m_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I0],
I0,
n_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I1]));
// space filling curve for threadwise C in VGPR
constexpr auto sfc_c_vgpr =
......@@ -862,121 +865,128 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// layernorm
{
// load from LDS and global, add bias
c_reduce_thread_copy_lds_to_vgpr.Run(c_reduce_block_desc_mperblock_nperblock,
c_shuffle_block_buf,
c_reduce_thread_desc_mperblock_nperblock,
make_tuple(I0, I0),
c_reduce_thread_buf);
c0_thread_copy_global_to_vgpr.Run(
c0_grid_desc_mblock_mperblock_nblock_nperblock,
c0_bias_grid_buf,
c_reduce_thread_desc_mblock_mperblock_nblock_nperblock,
make_tuple(I0, I0, I0, I0),
c0_thread_buf);
static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}([&](auto i) {
c_reduce_thread_buf(i) += c0_thread_buf(i);
});
using ThreadwiseReduceD0 =
ThreadwiseReduction<FloatReduceAcc,
decltype(c_reduce_thread_desc_mperblock_nperblock),
decltype(d_reduce_thread_desc_mperblock),
reduce::Add<FloatReduceAcc>,
false>;
using ThreadwiseReduceD1 =
ThreadwiseReduction<FloatReduceAcc,
decltype(c_reduce_thread_desc_mperblock_nperblock),
decltype(d_reduce_thread_desc_mperblock),
reduce::SquaredAdd<FloatReduceAcc>,
false>;
const auto d0_zeroVal = ThreadwiseReduceD0::Op::GetReductionZeroVal();
const auto d1_zeroVal = ThreadwiseReduceD1::Op::GetReductionZeroVal();
static_for<0, mreduce_per_thread, 1>{}(
[&](auto i) { d0_thread_buf(i) = d0_zeroVal; });
static_for<0, mreduce_per_thread, 1>{}(
[&](auto i) { d1_thread_buf(i) = d1_zeroVal; });
// reduce sum in VGPR
ThreadwiseReduceD0::Reduce(c_reduce_thread_buf, d0_thread_buf);
// reduce squared sum in VGPR
ThreadwiseReduceD1::Reduce(c_reduce_thread_buf, d1_thread_buf);
// reduce within workgroup
using BlockwiseReduce = PartitionedBlockwiseReduction<FloatReduceAcc,
BlockSize,
CReduceThreadClusterLengths_MPerBlock_NPerBlock, // ThreadClusterLengths_M_K
Sequence<1, 0>, // ThreadClusterArrangeOrder
reduce::Add<FloatReduceAcc>,
false>;
static_for<0, mreduce_per_thread, 1>{}([&](auto i) {
block_sync_lds();
BlockwiseReduce::Reduce(d_reduce_work_buf, d0_thread_buf(i)); // blockwise reduced sum
block_sync_lds();
BlockwiseReduce::Reduce(d_reduce_work_buf, d1_thread_buf(i)); // blockwise reduced squared sum
});
// normalize
const index_t NRaw = c_grid_desc_mblock_mperblock_nblock_nperblock.GetTransforms()[I0].GetUpperLengths()[I1]; // TODO: proper handle
static_for<0, mreduce_per_thread, 1>{}([&](auto im) {
static_for<0, nreduce_per_thread, 1>{}([&](auto in) {
constexpr auto dst_offset =
Number<c_reduce_thread_desc_mperblock_nperblock.CalculateOffset(
make_tuple(im, in))>{};
constexpr auto src_offset =
Number<d_reduce_thread_desc_mperblock.CalculateOffset(
make_tuple(im))>{};
FloatReduceAcc avg_sum = d0_thread_buf(src_offset) / NRaw;
FloatReduceAcc avg_squared_sum = d1_thread_buf(src_offset) / NRaw;
FloatReduceAcc denom = c_reduce_thread_buf(dst_offset) - avg_sum;
FloatReduceAcc divisor = epsilon + avg_squared_sum - avg_sum * avg_sum;
FloatReduceAcc divisor_sqrt;
tensor_operation::element_wise::UnarySqrt<FloatReduceAcc, FloatReduceAcc>{}(divisor_sqrt, divisor);
// load from LDS and global, add bias
c_reduce_thread_copy_lds_to_vgpr.Run(c_reduce_block_desc_mperblock_nperblock,
c_shuffle_block_buf,
c_reduce_thread_desc_mperblock_nperblock,
make_tuple(I0, I0),
c_reduce_thread_buf);
c0_thread_copy_global_to_vgpr.Run(
c0_grid_desc_mblock_mperblock_nblock_nperblock,
c0_bias_grid_buf,
c_reduce_thread_desc_mblock_mperblock_nblock_nperblock,
make_tuple(I0, I0, I0, I0),
c0_thread_buf);
static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}(
[&](auto i) { c_reduce_thread_buf(i) += c0_thread_buf(i); });
using ThreadwiseReduceD0 =
ThreadwiseReduction<FloatReduceAcc,
decltype(c_reduce_thread_desc_mperblock_nperblock),
decltype(d_reduce_thread_desc_mperblock),
reduce::Add<FloatReduceAcc>,
false>;
using ThreadwiseReduceD1 =
ThreadwiseReduction<FloatReduceAcc,
decltype(c_reduce_thread_desc_mperblock_nperblock),
decltype(d_reduce_thread_desc_mperblock),
reduce::SquaredAdd<FloatReduceAcc>,
false>;
const auto d0_zeroVal = ThreadwiseReduceD0::Op::GetReductionZeroVal();
const auto d1_zeroVal = ThreadwiseReduceD1::Op::GetReductionZeroVal();
static_for<0, mreduce_per_thread, 1>{}(
[&](auto i) { d0_thread_buf(i) = d0_zeroVal; });
static_for<0, mreduce_per_thread, 1>{}(
[&](auto i) { d1_thread_buf(i) = d1_zeroVal; });
// reduce sum in VGPR
ThreadwiseReduceD0::Reduce(c_reduce_thread_buf, d0_thread_buf);
// reduce squared sum in VGPR
ThreadwiseReduceD1::Reduce(c_reduce_thread_buf, d1_thread_buf);
// reduce within workgroup
using BlockwiseReduce = PartitionedBlockwiseReduction<
FloatReduceAcc,
BlockSize,
CReduceThreadClusterLengths_MPerBlock_NPerBlock, // ThreadClusterLengths_M_K
Sequence<1, 0>, // ThreadClusterArrangeOrder
reduce::Add<FloatReduceAcc>,
false>;
static_for<0, mreduce_per_thread, 1>{}([&](auto i) {
block_sync_lds();
BlockwiseReduce::Reduce(d_reduce_work_buf,
d0_thread_buf(i)); // blockwise reduced sum
block_sync_lds();
BlockwiseReduce::Reduce(d_reduce_work_buf,
d1_thread_buf(i)); // blockwise reduced squared sum
});
c_reduce_thread_buf(dst_offset) = denom / divisor_sqrt;
// normalize
const index_t NRaw =
c_grid_desc_mblock_mperblock_nblock_nperblock.GetTransforms()[I0]
.GetUpperLengths()[I1]; // TODO: proper handle
static_for<0, mreduce_per_thread, 1>{}([&](auto im) {
static_for<0, nreduce_per_thread, 1>{}([&](auto in) {
constexpr auto dst_offset =
Number<c_reduce_thread_desc_mperblock_nperblock.CalculateOffset(
make_tuple(im, in))>{};
constexpr auto src_offset =
Number<d_reduce_thread_desc_mperblock.CalculateOffset(
make_tuple(im))>{};
FloatReduceAcc avg_sum = d0_thread_buf(src_offset) / NRaw;
FloatReduceAcc avg_squared_sum = d1_thread_buf(src_offset) / NRaw;
FloatReduceAcc denom = c_reduce_thread_buf(dst_offset) - avg_sum;
FloatReduceAcc divisor = epsilon + avg_squared_sum - avg_sum * avg_sum;
FloatReduceAcc divisor_sqrt;
tensor_operation::element_wise::UnarySqrt<FloatReduceAcc,
FloatReduceAcc>{}(
divisor_sqrt, divisor);
c_reduce_thread_buf(dst_offset) = denom / divisor_sqrt;
});
});
});
// scaling
c0_thread_copy_global_to_vgpr.Run(
c0_grid_desc_mblock_mperblock_nblock_nperblock,
c0_gamma_grid_buf,
c_reduce_thread_desc_mblock_mperblock_nblock_nperblock,
make_tuple(I0, I0, I0, I0),
c0_thread_buf);
static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}([&](auto i) {
c_reduce_thread_buf(i) *= c0_thread_buf(i); // * gamma
});
c0_thread_copy_global_to_vgpr.Run(
c0_grid_desc_mblock_mperblock_nblock_nperblock,
c0_beta_grid_buf,
c_reduce_thread_desc_mblock_mperblock_nblock_nperblock,
make_tuple(I0, I0, I0, I0),
c0_thread_buf);
static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}([&](auto i) {
c_reduce_thread_buf(i) += c0_thread_buf(i); // + beta
});
block_sync_lds();
// scaling
c0_thread_copy_global_to_vgpr.Run(
c0_grid_desc_mblock_mperblock_nblock_nperblock,
c0_gamma_grid_buf,
c_reduce_thread_desc_mblock_mperblock_nblock_nperblock,
make_tuple(I0, I0, I0, I0),
c0_thread_buf);
static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}(
[&](auto i) {
c_reduce_thread_buf(i) *= c0_thread_buf(i); // * gamma
});
c0_thread_copy_global_to_vgpr.Run(
c0_grid_desc_mblock_mperblock_nblock_nperblock,
c0_beta_grid_buf,
c_reduce_thread_desc_mblock_mperblock_nblock_nperblock,
make_tuple(I0, I0, I0, I0),
c0_thread_buf);
static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}(
[&](auto i) {
c_reduce_thread_buf(i) += c0_thread_buf(i); // + beta
});
block_sync_lds();
c_reduce_thread_copy_vgpr_to_lds.Run(c_reduce_thread_desc_mperblock_nperblock,
make_tuple(I0, I0),
c_reduce_thread_buf,
c_reduce_block_desc_mperblock_nperblock,
c_shuffle_block_buf);
c_reduce_thread_copy_vgpr_to_lds.Run(c_reduce_thread_desc_mperblock_nperblock,
make_tuple(I0, I0),
c_reduce_thread_buf,
c_reduce_block_desc_mperblock_nperblock,
c_shuffle_block_buf);
} // end layernorm
......
......@@ -51,7 +51,7 @@ struct ThreadwiseReduction
static_assert(src_length_m == dst_length_m, "lengths of source and dst buffer must match!");
using Op = OpReduce;
using Op = OpReduce;
using Accumulation = detail::AccumulateWithNanCheck<PropagateNan, OpReduce, AccDataType>;
template <typename SrcBufferType, typename DstBufferType>
......
......@@ -12,10 +12,7 @@ template <typename T>
struct PrintAsType<T, typename std::enable_if<std::is_floating_point<T>::value>::type>
{
using type = float;
__host__ __device__ static void Print(const T& p)
{
printf("%.3f ", static_cast<type>(p));
}
__host__ __device__ static void Print(const T& p) { printf("%.3f ", static_cast<type>(p)); }
};
template <>
......@@ -32,10 +29,7 @@ template <typename T>
struct PrintAsType<T, typename std::enable_if<std::is_integral<T>::value>::type>
{
using type = int;
__host__ __device__ static void Print(const T& p)
{
printf("%d ", static_cast<type>(p));
}
__host__ __device__ static void Print(const T& p) { printf("%d ", static_cast<type>(p)); }
};
} // namespace detail
......
......@@ -18,21 +18,26 @@ template <typename ADataType,
typename CElementwiseOperation>
struct ReferenceGemmLayernorm : public device::BaseOperator
{
using ReferenceGemmInstance = ReferenceGemm<ADataType, BDataType, AccDataType, AccDataType,
AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>;
using ReferenceGemmInstance = ReferenceGemm<ADataType,
BDataType,
AccDataType,
AccDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>;
// D = Layernorm(acc + broadcast(bias)) * broadcast(gamma) + broadcast(beta)
template <typename InDataType, typename OutDataType, typename ComputeDataType>
static void RunLayernorm(Tensor<OutDataType>& result,
const Tensor<ComputeDataType>& acc, // MxN
const Tensor<InDataType>& bias, // 1xN
const Tensor<InDataType>& gamma, // 1xN
const Tensor<InDataType>& beta, // 1xN
const InDataType epsilon = 1e-5)
const Tensor<ComputeDataType>& acc, // MxN
const Tensor<InDataType>& bias, // 1xN
const Tensor<InDataType>& gamma, // 1xN
const Tensor<InDataType>& beta, // 1xN
const InDataType epsilon = 1e-5)
{
assert(acc.mDesc.GetLengths()[1] == bias.mDesc.GetLengths()[0] &&
acc.mDesc.GetLengths()[1] == gamma.mDesc.GetLengths()[0] &&
acc.mDesc.GetLengths()[1] == beta.mDesc.GetLengths()[0]);
acc.mDesc.GetLengths()[1] == gamma.mDesc.GetLengths()[0] &&
acc.mDesc.GetLengths()[1] == beta.mDesc.GetLengths()[0]);
size_t M = acc.mDesc.GetLengths()[0];
size_t N = acc.mDesc.GetLengths()[1];
......@@ -127,10 +132,14 @@ struct ReferenceGemmLayernorm : public device::BaseOperator
Tensor<AccDataType> acc_m_n(arg.c_m_n_.mDesc);
acc_m_n.GenerateTensorValue(GeneratorTensor_1<AccDataType>{0});
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
arg.a_m_k_, arg.b_k_n_, acc_m_n, arg.a_element_op_, arg.b_element_op_, arg.c_element_op_);
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(arg.a_m_k_,
arg.b_k_n_,
acc_m_n,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_);
ref_invoker.Run(ref_argument);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment