Commit 76f2b6cd authored by danyao12's avatar danyao12
Browse files

merge develop to attn-train-develop-qloop

parents 9b4c780a 1ee99dca
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp" #include "common.hpp"
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#error Should compile this file with ck::int4_t support #error Should compile this file with ck::int4_t support
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp" #include "common.hpp"
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream> #include <iostream>
#include <initializer_list> #include <initializer_list>
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include "ck/utility/reduction_enums.hpp" #include "ck/utility/reduction_enums.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" #include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp" #include "ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_reduce.hpp"
#include "ck/library/utility/algorithm.hpp" #include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp" #include "ck/library/utility/check_err.hpp"
...@@ -16,7 +17,6 @@ ...@@ -16,7 +17,6 @@
#include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/host_common_util.hpp" #include "ck/library/utility/host_common_util.hpp"
#include "ck/library/utility/host_reduction.hpp"
#include "reduce_example_common.hpp" #include "reduce_example_common.hpp"
...@@ -236,29 +236,6 @@ int reduce_blockwise_impl(bool do_verification, ...@@ -236,29 +236,6 @@ int reduce_blockwise_impl(bool do_verification,
reduce_unary_operator<ReduceOpId, true, true>::GetElementwiseOperator( reduce_unary_operator<ReduceOpId, true, true>::GetElementwiseOperator(
static_cast<int32_t>(reduce_total_length)); static_cast<int32_t>(reduce_total_length));
if(do_verification)
{
ReductionHost<InOutDataType,
AccDataType,
InOutDataType,
ReduceOperation,
InElementwiseOperation,
AccElementwiseOperation,
Rank,
NumReduceDim,
PropagateNan,
OutputIndex>
hostReduce(in.mDesc, out_ref.mDesc, invariantDims, reduceDims);
hostReduce.Run(alpha,
in.mData.data(),
beta,
out_ref.mData.data(),
out_indices_ref.mData.data(),
in_elementwise_op,
acc_elementwise_op);
};
std::array<index_t, Rank> arrInLengths; std::array<index_t, Rank> arrInLengths;
std::array<index_t, Rank> arrInStrides; std::array<index_t, Rank> arrInStrides;
std::array<index_t, NumOutDim> arrOutLengths; std::array<index_t, NumOutDim> arrOutLengths;
...@@ -269,6 +246,48 @@ int reduce_blockwise_impl(bool do_verification, ...@@ -269,6 +246,48 @@ int reduce_blockwise_impl(bool do_verification,
ck::ranges::copy(outLengths, arrOutLengths.begin()); ck::ranges::copy(outLengths, arrOutLengths.begin());
ck::ranges::copy(outStrides, arrOutStrides.begin()); ck::ranges::copy(outStrides, arrOutStrides.begin());
if(do_verification)
{
using ReferenceReduceInstance =
ck::tensor_operation::host::ReferenceReduce<InOutDataType,
AccDataType,
InOutDataType,
Rank,
NumReduceDim,
ReduceOperation,
InElementwiseOperation,
AccElementwiseOperation,
PropagateNan,
OutputIndex>;
auto reduce_ref = ReferenceReduceInstance{};
auto argument_ptr_ref = reduce_ref.MakeArgumentPointer(arrInLengths,
arrInStrides,
arrOutLengths,
arrOutStrides,
reduceDims,
static_cast<double>(alpha),
static_cast<double>(beta),
in.mData.data(),
nullptr,
out_ref.mData.data(),
out_indices_ref.mData.data(),
in_elementwise_op,
acc_elementwise_op);
if(!reduce_ref.IsSupportedArgument(argument_ptr_ref.get()))
{
std::cout << "The runtime parameters not supported by the reduce reference, exiting!"
<< std::endl;
return (false);
};
auto invoker_ptr_ref = reduce_ref.MakeInvokerPointer();
invoker_ptr_ref->Run(argument_ptr_ref.get());
};
auto reduce = DeviceReduceInstance{}; auto reduce = DeviceReduceInstance{};
auto argument_ptr = reduce.MakeArgumentPointer(arrInLengths, auto argument_ptr = reduce.MakeArgumentPointer(arrInLengths,
...@@ -276,8 +295,8 @@ int reduce_blockwise_impl(bool do_verification, ...@@ -276,8 +295,8 @@ int reduce_blockwise_impl(bool do_verification,
arrOutLengths, arrOutLengths,
arrOutStrides, arrOutStrides,
reduceDims, reduceDims,
alpha, static_cast<double>(alpha),
beta, static_cast<double>(beta),
in_dev.GetDeviceBuffer(), in_dev.GetDeviceBuffer(),
nullptr, nullptr,
out_dev.GetDeviceBuffer(), out_dev.GetDeviceBuffer(),
...@@ -287,9 +306,8 @@ int reduce_blockwise_impl(bool do_verification, ...@@ -287,9 +306,8 @@ int reduce_blockwise_impl(bool do_verification,
if(!reduce.IsSupportedArgument(argument_ptr.get())) if(!reduce.IsSupportedArgument(argument_ptr.get()))
{ {
std::cerr std::cerr << "The runtime parameters not supported by the DeviceReduce instance, exiting!"
<< "The runtime parameters seems not supported by the DeviceReduce instance, exiting!" << std::endl;
<< std::endl;
return (-2); return (-2);
}; };
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
...@@ -12,13 +12,13 @@ ...@@ -12,13 +12,13 @@
#include "ck/utility/reduction_enums.hpp" #include "ck/utility/reduction_enums.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" #include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp" #include "ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_reduce.hpp"
#include "ck/library/utility/check_err.hpp" #include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/host_common_util.hpp" #include "ck/library/utility/host_common_util.hpp"
#include "ck/library/utility/host_reduction.hpp"
using namespace ck; using namespace ck;
using namespace ck::tensor_operation::device; using namespace ck::tensor_operation::device;
...@@ -97,8 +97,8 @@ int main(int argc, char* argv[]) ...@@ -97,8 +97,8 @@ int main(int argc, char* argv[])
// const std::array<int, 3> invariantDims_2 = {0, 1, 2}; // const std::array<int, 3> invariantDims_2 = {0, 1, 2};
// used by the host reduction // used by the host reduction
const std::array<int, 2> reduceDims = {3, 4}; const std::array<int, 2> reduceDims = {3, 4};
const std::array<int, 3> invariantDims = {0, 1, 2}; // const std::array<int, 3> invariantDims = {0, 1, 2};
const std::vector<size_t> inLengths_1 = {64, 320, 80, 4, 128}; const std::vector<size_t> inLengths_1 = {64, 320, 80, 4, 128};
...@@ -191,29 +191,6 @@ int main(int argc, char* argv[]) ...@@ -191,29 +191,6 @@ int main(int argc, char* argv[])
reduce_unary_operator<ReduceOpId, true, true>::GetElementwiseOperator( reduce_unary_operator<ReduceOpId, true, true>::GetElementwiseOperator(
static_cast<int32_t>(reduce_total_length)); static_cast<int32_t>(reduce_total_length));
if(do_verify)
{
ReductionHost<InOutDataType,
AccDataType,
InOutDataType,
ReduceOperation,
InElementwiseOperation,
AccElementwiseOperation,
5, // Rank
2, // NumReduceDim
PropagateNan,
OutputIndex>
hostReduce(in_1.mDesc, out_ref.mDesc, invariantDims, reduceDims);
hostReduce.Run(alpha,
in_1.mData.data(),
beta,
out_ref.mData.data(),
nullptr,
in_elementwise_op,
acc_elementwise_op);
};
std::array<index_t, 5> arrInLengths_1; std::array<index_t, 5> arrInLengths_1;
std::array<index_t, 5> arrInStrides_1; std::array<index_t, 5> arrInStrides_1;
std::array<index_t, 4> arrInLengths_2; std::array<index_t, 4> arrInLengths_2;
...@@ -228,6 +205,48 @@ int main(int argc, char* argv[]) ...@@ -228,6 +205,48 @@ int main(int argc, char* argv[])
ck::ranges::copy(outLengths, arrOutLengths.begin()); ck::ranges::copy(outLengths, arrOutLengths.begin());
ck::ranges::copy(outStrides, arrOutStrides.begin()); ck::ranges::copy(outStrides, arrOutStrides.begin());
if(do_verify)
{
using ReferenceReduceInstance =
ck::tensor_operation::host::ReferenceReduce<InOutDataType,
AccDataType,
InOutDataType,
5,
2,
ReduceOperation,
InElementwiseOperation,
AccElementwiseOperation,
PropagateNan,
OutputIndex>;
auto reduce_ref = ReferenceReduceInstance{};
auto argument_ptr_ref = reduce_ref.MakeArgumentPointer(arrInLengths_1,
arrInStrides_1,
arrOutLengths,
arrOutStrides,
reduceDims,
static_cast<double>(alpha),
static_cast<double>(beta),
in_1.mData.data(),
nullptr,
out_ref.mData.data(),
nullptr,
in_elementwise_op,
acc_elementwise_op);
if(!reduce_ref.IsSupportedArgument(argument_ptr_ref.get()))
{
std::cout << "The runtime parameters not supported by the reduce reference, exiting!"
<< std::endl;
return (false);
};
auto invoker_ptr_ref = reduce_ref.MakeInvokerPointer();
invoker_ptr_ref->Run(argument_ptr_ref.get());
};
auto reduce_1 = DeviceReduceInstance_1{}; auto reduce_1 = DeviceReduceInstance_1{};
auto argument_ptr_1 = reduce_1.MakeArgumentPointer(arrInLengths_1, auto argument_ptr_1 = reduce_1.MakeArgumentPointer(arrInLengths_1,
...@@ -235,8 +254,8 @@ int main(int argc, char* argv[]) ...@@ -235,8 +254,8 @@ int main(int argc, char* argv[])
arrInLengths_2, arrInLengths_2,
arrInStrides_2, arrInStrides_2,
reduceDims_1, reduceDims_1,
1.0f, 1.0,
0.0f, 0.0,
in_1_dev.GetDeviceBuffer(), in_1_dev.GetDeviceBuffer(),
nullptr, nullptr,
in_2_dev.GetDeviceBuffer(), in_2_dev.GetDeviceBuffer(),
...@@ -246,9 +265,8 @@ int main(int argc, char* argv[]) ...@@ -246,9 +265,8 @@ int main(int argc, char* argv[])
if(!reduce_1.IsSupportedArgument(argument_ptr_1.get())) if(!reduce_1.IsSupportedArgument(argument_ptr_1.get()))
{ {
std::cout std::cout << "The runtime parameters seems supported by the DeviceReduce instance, exiting!"
<< "The runtime parameters seems not supported by the DeviceReduce instance, exiting!" << std::endl;
<< std::endl;
}; };
auto invoker_ptr_1 = reduce_1.MakeInvokerPointer(); auto invoker_ptr_1 = reduce_1.MakeInvokerPointer();
...@@ -260,8 +278,8 @@ int main(int argc, char* argv[]) ...@@ -260,8 +278,8 @@ int main(int argc, char* argv[])
arrOutLengths, arrOutLengths,
arrOutStrides, arrOutStrides,
reduceDims_2, reduceDims_2,
alpha, static_cast<double>(alpha),
beta, static_cast<double>(beta),
in_2_dev.GetDeviceBuffer(), in_2_dev.GetDeviceBuffer(),
nullptr, nullptr,
out_dev.GetDeviceBuffer(), out_dev.GetDeviceBuffer(),
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream> #include <iostream>
#include <initializer_list> #include <initializer_list>
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include "ck/utility/reduction_enums.hpp" #include "ck/utility/reduction_enums.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" #include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp" #include "ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_reduce.hpp"
#include "ck/library/utility/algorithm.hpp" #include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp" #include "ck/library/utility/check_err.hpp"
...@@ -16,7 +17,6 @@ ...@@ -16,7 +17,6 @@
#include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/host_common_util.hpp" #include "ck/library/utility/host_common_util.hpp"
#include "ck/library/utility/host_reduction.hpp"
#include "reduce_example_common.hpp" #include "reduce_example_common.hpp"
...@@ -149,29 +149,6 @@ int reduce_multiblock_atomic_add_impl(bool do_verification, ...@@ -149,29 +149,6 @@ int reduce_multiblock_atomic_add_impl(bool do_verification,
reduce_unary_operator<ReduceOpId, true, true>::GetElementwiseOperator( reduce_unary_operator<ReduceOpId, true, true>::GetElementwiseOperator(
static_cast<int32_t>(reduce_total_length)); static_cast<int32_t>(reduce_total_length));
if(do_verification)
{
ReductionHost<InOutDataType,
AccDataType,
InOutDataType,
ReduceOperation,
InElementwiseOperation,
AccElementwiseOperation,
Rank,
NumReduceDim,
PropagateNan,
false>
hostReduce(in.mDesc, out_ref.mDesc, invariantDims, reduceDims);
hostReduce.Run(alpha,
in.mData.data(),
beta,
out_ref.mData.data(),
nullptr,
in_elementwise_op,
acc_elementwise_op);
};
std::array<index_t, Rank> arrInLengths; std::array<index_t, Rank> arrInLengths;
std::array<index_t, Rank> arrInStrides; std::array<index_t, Rank> arrInStrides;
std::array<index_t, NumOutDim> arrOutLengths; std::array<index_t, NumOutDim> arrOutLengths;
...@@ -182,6 +159,48 @@ int reduce_multiblock_atomic_add_impl(bool do_verification, ...@@ -182,6 +159,48 @@ int reduce_multiblock_atomic_add_impl(bool do_verification,
ck::ranges::copy(outLengths, arrOutLengths.begin()); ck::ranges::copy(outLengths, arrOutLengths.begin());
ck::ranges::copy(outStrides, arrOutStrides.begin()); ck::ranges::copy(outStrides, arrOutStrides.begin());
if(do_verification)
{
using ReferenceReduceInstance =
ck::tensor_operation::host::ReferenceReduce<InOutDataType,
AccDataType,
InOutDataType,
Rank,
NumReduceDim,
ReduceOperation,
InElementwiseOperation,
AccElementwiseOperation,
PropagateNan,
false>;
auto reduce_ref = ReferenceReduceInstance{};
auto argument_ptr_ref = reduce_ref.MakeArgumentPointer(arrInLengths,
arrInStrides,
arrOutLengths,
arrOutStrides,
reduceDims,
static_cast<double>(alpha),
static_cast<double>(beta),
in.mData.data(),
nullptr,
out_ref.mData.data(),
nullptr,
in_elementwise_op,
acc_elementwise_op);
if(!reduce_ref.IsSupportedArgument(argument_ptr_ref.get()))
{
std::cout << "The runtime parameters not supported by the reduce reference, exiting!"
<< std::endl;
return (false);
};
auto invoker_ptr_ref = reduce_ref.MakeInvokerPointer();
invoker_ptr_ref->Run(argument_ptr_ref.get());
};
auto reduce = DeviceReduceInstance{}; auto reduce = DeviceReduceInstance{};
auto argument_ptr = reduce.MakeArgumentPointer(arrInLengths, auto argument_ptr = reduce.MakeArgumentPointer(arrInLengths,
...@@ -189,8 +208,8 @@ int reduce_multiblock_atomic_add_impl(bool do_verification, ...@@ -189,8 +208,8 @@ int reduce_multiblock_atomic_add_impl(bool do_verification,
arrOutLengths, arrOutLengths,
arrOutStrides, arrOutStrides,
reduceDims, reduceDims,
alpha, static_cast<double>(alpha),
beta, static_cast<double>(beta),
in_dev.GetDeviceBuffer(), in_dev.GetDeviceBuffer(),
nullptr, nullptr,
out_dev.GetDeviceBuffer(), out_dev.GetDeviceBuffer(),
...@@ -200,9 +219,8 @@ int reduce_multiblock_atomic_add_impl(bool do_verification, ...@@ -200,9 +219,8 @@ int reduce_multiblock_atomic_add_impl(bool do_verification,
if(!reduce.IsSupportedArgument(argument_ptr.get())) if(!reduce.IsSupportedArgument(argument_ptr.get()))
{ {
std::cerr std::cerr << "The runtime parameters not supported by the DeviceReduce instance, exiting!"
<< "The runtime parameters seems not supported by the DeviceReduce instance, exiting!" << std::endl;
<< std::endl;
return (-2); return (-2);
}; };
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -17,115 +17,11 @@ ...@@ -17,115 +17,11 @@
#include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp" #include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_pool_fwd.hpp"
template <typename InDataType, template <typename InDataType,
typename OutDataType, typename OutDataType,
typename AccDataType, typename ComputeDataType,
typename IndexDataType,
ck::ReduceTensorOp ReduceOpId,
bool PropagateNan,
bool OutputIndex>
static void pool_host_verify(const Tensor<InDataType>& in,
Tensor<OutDataType>& out,
Tensor<IndexDataType>& out_indices,
const std::array<ck::index_t, 2>& window_spatial_lengths,
const std::array<ck::index_t, 2>& window_strides,
const std::array<ck::index_t, 2>& in_left_pads,
const std::array<ck::index_t, 2>& /*in_right_pads*/)
{
const int32_t reduceLength = window_spatial_lengths[0] * window_spatial_lengths[1];
using ReduceOperation = typename ck::reduce_binary_operator<ReduceOpId>::opType;
auto elementwise_ops =
ck::reduce_unary_operator<ReduceOpId, true, true>::GetElementwiseOperator(reduceLength);
auto in_elementwise_op = std::get<0>(elementwise_ops);
auto acc_elementwise_op = std::get<1>(elementwise_ops);
if constexpr(!OutputIndex)
{
using Accumulation =
ck::detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;
auto f_nchw = [&](auto n, auto c, auto ho, auto wo) {
auto accuVal = ReduceOperation::template GetIdentityValue<AccDataType>();
for(ck::index_t y = 0; y < window_spatial_lengths[0]; ++y)
{
ck::index_t hi = ho * window_strides[0] + y - in_left_pads[0];
for(ck::index_t x = 0; x < window_spatial_lengths[1]; ++x)
{
ck::index_t wi = wo * window_strides[1] + x - in_left_pads[1];
if(hi >= 0 && hi < static_cast<ck::index_t>(in.mDesc.GetLengths()[2]) &&
wi >= 0 && wi < static_cast<ck::index_t>(in.mDesc.GetLengths()[3]))
{
AccDataType currVal = static_cast<AccDataType>(in(n, c, hi, wi));
in_elementwise_op(currVal, currVal);
Accumulation::Calculate(accuVal, currVal);
}
}
}
acc_elementwise_op(accuVal, accuVal);
out(n, c, ho, wo) = accuVal;
};
make_ParallelTensorFunctor(f_nchw,
out.mDesc.GetLengths()[0],
out.mDesc.GetLengths()[1],
out.mDesc.GetLengths()[2],
out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
}
else
{
using Accumulation = ck::detail::AccumulateWithIndexAndNanCheck<PropagateNan,
ReduceOperation,
AccDataType,
IndexDataType>;
auto f_nchw = [&](auto n, auto c, auto ho, auto wo) {
auto accuVal = ReduceOperation::template GetIdentityValue<AccDataType>();
IndexDataType accuIndex = 0;
for(ck::index_t y = 0; y < window_spatial_lengths[0]; ++y)
{
ck::index_t hi = ho * window_strides[0] + y - in_left_pads[0];
for(ck::index_t x = 0; x < window_spatial_lengths[1]; ++x)
{
ck::index_t wi = wo * window_strides[1] + x - in_left_pads[1];
if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 &&
wi < in.mDesc.GetLengths()[3])
{
AccDataType currVal = static_cast<AccDataType>(in(n, c, hi, wi));
IndexDataType currIndex = y * window_spatial_lengths[1] + x;
in_elementwise_op(currVal, currVal);
Accumulation::Calculate(accuVal, currVal, accuIndex, currIndex);
}
}
}
acc_elementwise_op(accuVal, accuVal);
out(n, c, ho, wo) = accuVal;
out_indices(n, c, ho, wo) = accuIndex;
};
make_ParallelTensorFunctor(f_nchw,
out.mDesc.GetLengths()[0],
out.mDesc.GetLengths()[1],
out.mDesc.GetLengths()[2],
out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
};
}
template <typename InDataType,
typename OutDataType,
typename AccDataType,
typename IndexDataType, typename IndexDataType,
typename InLayout, typename InLayout,
typename OutLayout, typename OutLayout,
...@@ -150,9 +46,10 @@ bool pool_test(bool do_verification, ...@@ -150,9 +46,10 @@ bool pool_test(bool do_verification,
{ {
using DevicePoolFwdInstance = using DevicePoolFwdInstance =
ck::tensor_operation::device::DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C< ck::tensor_operation::device::DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C<
InDataType, // InDataType InDataType, // InDataType
OutDataType, // OutDataType OutDataType, // OutDataType
AccDataType, // AccDataType IndexDataType, // IndexDataType
ComputeDataType, // ComputeDataType
ReduceOpId, ReduceOpId,
OutputIndex, OutputIndex,
64, // BlockSize 64, // BlockSize
...@@ -165,10 +62,10 @@ bool pool_test(bool do_verification, ...@@ -165,10 +62,10 @@ bool pool_test(bool do_verification,
const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - Y) / window_stride_h + 1; const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - Y) / window_stride_h + 1;
const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - X) / window_stride_w + 1; const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - X) / window_stride_w + 1;
const std::array<ck::index_t, 2> window_spatial_lengths{{Y, X}}; const std::vector<ck::index_t> window_spatial_lengths{Y, X};
const std::array<ck::index_t, 2> window_strides{{window_stride_h, window_stride_w}}; const std::vector<ck::index_t> window_strides{window_stride_h, window_stride_w};
const std::array<ck::index_t, 2> input_left_pads{{in_left_pad_h, in_left_pad_w}}; const std::vector<ck::index_t> input_left_pads{in_left_pad_h, in_left_pad_w};
const std::array<ck::index_t, 2> input_right_pads{{in_right_pad_h, in_right_pad_w}}; const std::vector<ck::index_t> input_right_pads{in_right_pad_h, in_right_pad_w};
// tensor layout // tensor layout
auto f_host_tensor_descriptor = auto f_host_tensor_descriptor =
...@@ -219,14 +116,16 @@ bool pool_test(bool do_verification, ...@@ -219,14 +116,16 @@ bool pool_test(bool do_verification,
static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()), static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()), static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
static_cast<IndexDataType*>(out_indices_device_buf.GetDeviceBuffer()), static_cast<IndexDataType*>(out_indices_device_buf.GetDeviceBuffer()),
N, {N, C, Hi, Wi},
C, {Y, X},
std::array<ck::index_t, 2>{{Hi, Wi}}, {N, C, Ho, Wo},
std::array<ck::index_t, 2>{{Y, X}}, {C * Hi * Wi, 1, Wi * C, C},
std::array<ck::index_t, 2>{{Ho, Wo}}, {C * Ho * Wo, 1, Wo * C, C},
{C * Ho * Wo, 1, Wo * C, C},
window_strides, window_strides,
input_left_pads, input_left_pads,
input_right_pads); input_right_pads,
{2, 3});
if(!pool.IsSupportedArgument(argument_ptr.get())) if(!pool.IsSupportedArgument(argument_ptr.get()))
{ {
...@@ -252,19 +151,28 @@ bool pool_test(bool do_verification, ...@@ -252,19 +151,28 @@ bool pool_test(bool do_verification,
if(do_verification) if(do_verification)
{ {
pool_host_verify<InDataType, using ReferencePoolingFwdInstance =
OutDataType, ck::tensor_operation::host::ReferencePoolingFwd<4,
AccDataType, 2,
IndexDataType, InDataType,
ReduceOpId, OutDataType,
PropagateNan, ComputeDataType,
OutputIndex>(in_n_c_hi_wi, IndexDataType,
out_n_c_ho_wo_host, ReduceOpId,
out_indices_n_c_ho_wo_host, PropagateNan,
window_spatial_lengths, OutputIndex>;
window_strides,
input_left_pads, auto ref_pooling = ReferencePoolingFwdInstance{};
input_right_pads); auto ref_pooling_invoker = ref_pooling.MakeInvoker();
auto ref_pooling_argument = ref_pooling.MakeArgument(in_n_c_hi_wi,
out_n_c_ho_wo_host,
out_indices_n_c_ho_wo_host,
window_spatial_lengths,
window_strides,
input_left_pads,
input_right_pads);
ref_pooling_invoker.Run(ref_pooling_argument);
out_device_buf.FromDevice(out_n_c_ho_wo_device.mData.data()); out_device_buf.FromDevice(out_n_c_ho_wo_device.mData.data());
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream> #include <iostream>
#include <cstdlib>
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
...@@ -10,9 +9,9 @@ ...@@ -10,9 +9,9 @@
#include "pool2d_fwd_common.hpp" #include "pool2d_fwd_common.hpp"
using InDataType = ck::half_t; using InDataType = ck::half_t;
using OutDataType = ck::half_t; using OutDataType = ck::half_t;
using AccDataType = float; using ComputeDataType = float;
using IndexDataType = int32_t; using IndexDataType = int32_t;
...@@ -91,7 +90,7 @@ int main(int argc, char* argv[]) ...@@ -91,7 +90,7 @@ int main(int argc, char* argv[])
bool pass = pool_test<InDataType, bool pass = pool_test<InDataType,
OutDataType, OutDataType,
AccDataType, ComputeDataType,
IndexDataType, IndexDataType,
InLayout, InLayout,
OutLayout, OutLayout,
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream> #include <iostream>
#include <cstdlib>
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/utility/reduction_enums.hpp" #include "ck/utility/reduction_enums.hpp"
...@@ -10,9 +9,9 @@ ...@@ -10,9 +9,9 @@
#include "pool2d_fwd_common.hpp" #include "pool2d_fwd_common.hpp"
using InDataType = float; using InDataType = float;
using OutDataType = float; using OutDataType = float;
using AccDataType = float; using ComputeDataType = float;
using IndexDataType = int32_t; using IndexDataType = int32_t;
...@@ -91,7 +90,7 @@ int main(int argc, char* argv[]) ...@@ -91,7 +90,7 @@ int main(int argc, char* argv[])
bool pass = pool_test<InDataType, bool pass = pool_test<InDataType,
OutDataType, OutDataType,
AccDataType, ComputeDataType,
IndexDataType, IndexDataType,
InLayout, InLayout,
OutLayout, OutLayout,
......
add_example_executable(example_gemm_xdl_bias_relu_quantization_int8 gemm_xdl_bias_relu_quantization_int8.cpp) # dlops
add_example_executable(example_gemm_xdl_quantization_int8 gemm_xdl_quantization_int8.cpp) add_example_executable(example_gemm_dl_quantization_int8 gemm_dl_quantization_int8.cpp)
\ No newline at end of file
# xdlops
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_example_executable(example_gemm_xdl_bias_relu_quantization_int8 gemm_xdl_bias_relu_quantization_int8.cpp)
add_example_executable(example_gemm_xdl_quantization_int8 gemm_xdl_quantization_int8.cpp)
set(target 1)
endif()
endforeach()
\ No newline at end of file
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using I8 = int8_t;
using I32 = int32_t;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using ActivationOp = PassThrough;
using CDEElementOp = ck::tensor_operation::element_wise::Activation_Mul_Clamp<ActivationOp>;
using ADataType = I8;
using BDataType = I8;
using AccDataType = I32;
using CShuffleDataType = I32;
using DsDataType = ck::Tuple<>;
using EDataType = I8;
using ALayout = Row;
using BLayout = Col;
using DsLayout = ck::Tuple<>;
using ELayout = Row;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Dl<
ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
AccDataType,
DsDataType,
EDataType,
AElementOp,
BElementOp,
CDEElementOp,
GemmDefault,
256, // BlockSize
128, // MPerBlock
128, // NPerBlock
16, // K0PerBlock
4, // K1
4, // M1PerThread
4, // N1PerThread
1, // KPerThread
S<8, 2>, // M1N1ThreadClusterM1Xs
S<8, 2>, // M1N1ThreadClusterN1Xs
S<8, 1, 1, 4>, // ABlockTransferThreadSliceLengths_K0_M0_M1_K1
S<2, 1, 128, 1>, // ABlockTransferThreadClusterLengths_K0_M0_M1_K1
S<1, 2, 0, 3>, // ABlockTransferThreadClusterArrangeOrder
S<1, 2, 0, 3>, // ABlockTransferSrcAccessOrder
S<4, 1, 1, 4>, // ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1
S<1, 2, 0, 3>, // ABlockTransferSrcVectorTensorContiguousDimOrder
S<1, 1, 1, 4>, // ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1
S<8, 1, 1, 4>, // BBlockTransferThreadSliceLengths_K0_N0_N1_K1
S<2, 1, 128, 1>, // BBlockTransferThreadClusterLengths_K0_N0_N1_K1
S<1, 2, 0, 3>, // BBlockTransferThreadClusterArrangeOrder
S<1, 2, 0, 3>, // BBlockTransferSrcAccessOrder
S<4, 1, 1, 4>, // BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1
S<1, 2, 0, 3>, // BBlockTransferSrcVectorTensorContiguousDimOrder
S<1, 1, 1, 4>, // BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1
S<0, 1, 2, 3, 4, 5>, // CThreadTransferSrcDstAccessOrder
5, // CThreadTransferSrcDstVectorDim
4>; // CThreadTransferDstScalarPerVector
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, EDataType, float, PassThrough, PassThrough, CDEElementOp>;
int main()
{
bool do_verification = true;
bool time_kernel = false;
// GEMM shape
ck::index_t M = 1024;
ck::index_t N = 1024;
ck::index_t K = 1024;
ck::index_t StrideA = 1024;
ck::index_t StrideB = 1024;
ck::index_t StrideE = 1024;
float requant_scale = 0.03;
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
using namespace ck::literals;
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({stride, 1_uz}));
}
else
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({1_uz, stride}));
}
};
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{requant_scale, ActivationOp{}};
// do GEMM
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
{},
e_device_buf.GetDeviceBuffer(),
M,
N,
K,
StrideA,
StrideB,
{},
StrideE,
a_element_op,
b_element_op,
cde_element_op);
if(!gemm.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl;
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
if(do_verification)
{
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a_m_k, b_k_n, e_m_n_host_result, a_element_op, b_element_op, cde_element_op);
ref_invoker.Run(ref_argument);
return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1;
}
return 0;
}
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
......
...@@ -4,12 +4,17 @@ add_example_executable(example_grouped_gemm_xdl_fp32 grouped_gemm_xdl_fp32.cpp) ...@@ -4,12 +4,17 @@ add_example_executable(example_grouped_gemm_xdl_fp32 grouped_gemm_xdl_fp32.cpp)
add_example_executable(example_grouped_gemm_xdl_fp16 grouped_gemm_xdl_fp16.cpp) add_example_executable(example_grouped_gemm_xdl_fp16 grouped_gemm_xdl_fp16.cpp)
add_example_executable(example_grouped_gemm_xdl_bfp16 grouped_gemm_xdl_bfp16.cpp) add_example_executable(example_grouped_gemm_xdl_bfp16 grouped_gemm_xdl_bfp16.cpp)
add_example_executable(example_grouped_gemm_xdl_int8 grouped_gemm_xdl_int8.cpp) add_example_executable(example_grouped_gemm_xdl_int8 grouped_gemm_xdl_int8.cpp)
add_example_executable(example_grouped_gemm_multiple_d_dl_fp16 grouped_gemm_multiple_d_dl_fp16.cpp)
add_example_executable(example_grouped_gemm_xdl_splitk_fp16 grouped_gemm_xdl_splitk_fp16.cpp)
add_dependencies(example_grouped_gemm_xdl add_dependencies(example_grouped_gemm_xdl
example_grouped_gemm_xdl_fp32 example_grouped_gemm_xdl_fp32
example_grouped_gemm_xdl_fp16 example_grouped_gemm_xdl_fp16
example_grouped_gemm_xdl_bfp16 example_grouped_gemm_xdl_bfp16
example_grouped_gemm_xdl_int8) example_grouped_gemm_xdl_int8
example_grouped_gemm_multiple_d_dl_fp16
example_grouped_gemm_xdl_splitk_fp16)
if(USE_BITINT_EXTENSION_INT4) if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_grouped_gemm_xdl_int4 grouped_gemm_xdl_int4.cpp) add_example_executable(example_grouped_gemm_xdl_int4 grouped_gemm_xdl_int4.cpp)
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <algorithm>
#include <cstddef>
#include <initializer_list>
#include <iostream>
#include <numeric>
#include <stdexcept>
#include <string>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/tuple.hpp"
#include "ck/utility/sequence.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = F16;
using BDataType = F16;
using AccDataType = F32;
using DsDataType = ck::Tuple<>;
using EDataType = F16;
using ALayout = Row;
using BLayout = Row;
using DsLayout = ck::Tuple<>;
using ELayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = PassThrough;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNPadding;
// clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::
// ##################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
// ##################| | | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ##################| | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | |
// ##################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemmMultipleD_Dl< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>;
// clang-format on
#include "run_grouped_gemm_example.inc"
int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); }
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment