Commit 8a370fbb authored by Chao Liu's avatar Chao Liu
Browse files

Merge remote-tracking branch 'origin/develop' into group_conv

parents d8fdd226 85978e02
#pragma once
#include <iostream>
#include <vector>
#include "device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
struct GemmDesc
{
ck::index_t M_, N_, K_;
ck::index_t stride_A_, stride_B_, stride_C_;
std::vector<ck::index_t> stride_Ds_;
};
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
struct DeviceGroupedGemm : public BaseOperator
{
static constexpr index_t NumDTensor = DsDataType::Size();
static_assert(DsLayout::Size() == DsDataType::Size(), "wrong! inconsisiten NumDTensor");
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(std::vector<const void*>& p_a,
std::vector<const void*>& p_b,
std::vector<std::array<const void*, NumDTensor>>& p_ds,
std::vector<void*>& p_e,
std::vector<GemmDesc>& gemm_desc,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using DsType = Tuple<>;
void add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Row,
Row,
F16,
F16,
DsType,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Col,
Row,
F16,
F16,
DsType,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
Row,
Row,
F16,
F16,
DsType,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
Col,
Row,
F16,
F16,
DsType,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
template <typename ALayout,
typename BLayout,
typename CLayout,
typename ADataType,
typename BDataType,
typename DsDataType,
typename EDataType>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedGemm<
ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
DsDataType,
EDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>>
{
using DeviceOp = DeviceGroupedGemm<ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
DsDataType,
EDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<EDataType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(op_ptrs);
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -31,9 +31,8 @@ check_err(const std::vector<T>& out, ...@@ -31,9 +31,8 @@ check_err(const std::vector<T>& out,
{ {
if(out.size() != ref.size()) if(out.size() != ref.size())
{ {
std::cout << "out.size() != ref.size(), :" << out.size() << " != " << ref.size() std::cout << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
<< std::endl << std::endl;
<< msg << std::endl;
return false; return false;
} }
...@@ -50,9 +49,8 @@ check_err(const std::vector<T>& out, ...@@ -50,9 +49,8 @@ check_err(const std::vector<T>& out,
err_count++; err_count++;
if(err_count < 5) if(err_count < 5)
{ {
std::cout << std::setw(12) << std::setprecision(7) << "out[" << i << "] != ref[" std::cout << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< i << "]: " << out[i] << " != " << ref[i] << std::endl << "] != ref[" << i << "]: " << out[i] << " != " << ref[i] << std::endl;
<< msg << std::endl;
} }
res = false; res = false;
} }
...@@ -74,9 +72,8 @@ check_err(const std::vector<T>& out, ...@@ -74,9 +72,8 @@ check_err(const std::vector<T>& out,
{ {
if(out.size() != ref.size()) if(out.size() != ref.size())
{ {
std::cout << "out.size() != ref.size(), :" << out.size() << " != " << ref.size() std::cout << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
<< std::endl << std::endl;
<< msg << std::endl;
return false; return false;
} }
...@@ -96,9 +93,8 @@ check_err(const std::vector<T>& out, ...@@ -96,9 +93,8 @@ check_err(const std::vector<T>& out,
err_count++; err_count++;
if(err_count < 5) if(err_count < 5)
{ {
std::cout << std::setw(12) << std::setprecision(7) << "out[" << i << "] != ref[" std::cout << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< i << "]: " << o << " != " << r << std::endl << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
<< msg << std::endl;
} }
res = false; res = false;
} }
...@@ -120,9 +116,8 @@ check_err(const std::vector<T>& out, ...@@ -120,9 +116,8 @@ check_err(const std::vector<T>& out,
{ {
if(out.size() != ref.size()) if(out.size() != ref.size())
{ {
std::cout << "out.size() != ref.size(), :" << out.size() << " != " << ref.size() std::cout << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
<< std::endl << std::endl;
<< msg << std::endl;
return false; return false;
} }
...@@ -141,9 +136,8 @@ check_err(const std::vector<T>& out, ...@@ -141,9 +136,8 @@ check_err(const std::vector<T>& out,
err_count++; err_count++;
if(err_count < 5) if(err_count < 5)
{ {
std::cout << std::setw(12) << std::setprecision(7) << "out[" << i << "] != ref[" std::cout << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< i << "]: " << o << " != " << r << std::endl << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
<< msg << std::endl;
} }
res = false; res = false;
} }
...@@ -165,9 +159,8 @@ check_err(const std::vector<T>& out, ...@@ -165,9 +159,8 @@ check_err(const std::vector<T>& out,
{ {
if(out.size() != ref.size()) if(out.size() != ref.size())
{ {
std::cout << "out.size() != ref.size(), :" << out.size() << " != " << ref.size() std::cout << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
<< std::endl << std::endl;
<< msg << std::endl;
return false; return false;
} }
...@@ -187,9 +180,9 @@ check_err(const std::vector<T>& out, ...@@ -187,9 +180,9 @@ check_err(const std::vector<T>& out,
err_count++; err_count++;
if(err_count < 5) if(err_count < 5)
{ {
std::cout << "out[" << i << "] != ref[" << i << "]: " << static_cast<int>(out[i]) std::cout << msg << " out[" << i << "] != ref[" << i
<< " != " << static_cast<int>(ref[i]) << std::endl << "]: " << static_cast<int>(out[i]) << " != " << static_cast<int>(ref[i])
<< msg << std::endl; << std::endl;
} }
res = false; res = false;
} }
......
...@@ -433,52 +433,3 @@ struct Tensor ...@@ -433,52 +433,3 @@ struct Tensor
HostTensorDescriptor mDesc; HostTensorDescriptor mDesc;
std::vector<T> mData; std::vector<T> mData;
}; };
#if 1
// FIXME: remove
template <typename T>
float check_error(const Tensor<T>& ref, const Tensor<T>& result)
{
float l1_error = 0;
float linf_error = -1;
float linf_rel_error = -1;
float linf_ref_value = 0, linf_result_value = 0;
float linf_rel_ref_value = 0, linf_rel_result_value = 0;
constexpr float eps = 1e-10;
for(std::size_t i = 0; i < ref.mData.size(); ++i)
{
float ref_v = ck::type_convert<float>(ref.mData[i]);
float result_v = ck::type_convert<float>(result.mData[i]);
float diff = std::abs(ref_v - result_v);
float rel_diff = diff / std::max(std::abs(ref_v), eps);
l1_error += diff;
if(linf_error < diff)
{
linf_error = diff;
linf_ref_value = ref_v;
linf_result_value = result_v;
}
if(linf_rel_error < rel_diff)
{
linf_rel_error = rel_diff;
linf_rel_ref_value = ref_v;
linf_rel_result_value = result_v;
}
}
std::cout << "Absolute Error L1 Norm (sum of abs diff): " << l1_error << std::endl;
std::cout << "Absolute Error L-inf Norm (max abs diff): " << linf_error << ", ref "
<< linf_ref_value << ", result " << linf_result_value << std::endl;
std::cout << "Relative Error L-inf Norm (max relative abs diff): " << linf_rel_error << ", ref "
<< linf_rel_ref_value << ", result " << linf_rel_result_value << std::endl;
return linf_error;
}
#endif
...@@ -23,6 +23,8 @@ using Col = ck::tensor_layout::gemm::ColumnMajor; ...@@ -23,6 +23,8 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
using DsType = ck::Tuple<>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
...@@ -30,23 +32,40 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa ...@@ -30,23 +32,40 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
// Compilation parameters for a[k, m] * b[k, n] = c[m, n] // Compilation parameters for a[k, m] * b[k, n] = c[m, n]
using device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances = std::tuple< using device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances = std::tuple<
// clang-format off // clang-format off
//#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| //##################| ALayout| BLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| //##################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| //##################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //##################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, DeviceGroupedGemmXdl< Col, Row, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, DeviceGroupedGemmXdl< Col, Row, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, DeviceGroupedGemmXdl< Col, Row, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 2, 2, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, DeviceGroupedGemmXdl< Col, Row, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, DeviceGroupedGemmXdl< Col, Row, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, DeviceGroupedGemmXdl< Col, Row, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, DeviceGroupedGemmXdl< Col, Row, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1> DeviceGroupedGemmXdl< Col, Row, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmXdl< Col, Row, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 2, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedGemmXdl< Col, Row, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedGemmXdl< Col, Row, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemmXdl< Col, Row, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemmXdl< Col, Row, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 2, 2, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmXdl< Col, Row, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmXdl< Col, Row, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmXdl< Col, Row, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>
// clang-format on // clang-format on
>; >;
void add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances( void add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(
std::vector<DeviceGroupedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances) std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
Row,
Row,
F16,
F16,
DsType,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances{}); device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances{});
......
...@@ -23,30 +23,48 @@ using Col = ck::tensor_layout::gemm::ColumnMajor; ...@@ -23,30 +23,48 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
using DsType = ck::Tuple<>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// Compilation parameters for a[k, m] * b[n, k] = c[m, n]
using device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances = std::tuple< using device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances = std::tuple<
// clang-format off // clang-format off
//#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| //##################| ALayout| BLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| //##################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| //##################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //##################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, DeviceGroupedGemmXdl< Col, Col, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, DeviceGroupedGemmXdl< Col, Col, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, DeviceGroupedGemmXdl< Col, Col, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, DeviceGroupedGemmXdl< Col, Col, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, DeviceGroupedGemmXdl< Col, Col, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, DeviceGroupedGemmXdl< Col, Col, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, DeviceGroupedGemmXdl< Col, Col, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> DeviceGroupedGemmXdl< Col, Col, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmXdl< Col, Col, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 2, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedGemmXdl< Col, Col, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedGemmXdl< Col, Col, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemmXdl< Col, Col, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemmXdl< Col, Col, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 2, 8, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmXdl< Col, Col, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmXdl< Col, Col, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmXdl< Col, Col, Row, F16, F16, F32, F16, DsType, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>
// clang-format on // clang-format on
>; >;
void add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances( void add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(
std::vector<DeviceGroupedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances) std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
Col,
Row,
F16,
F16,
DsType,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances{}); device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances{});
......
...@@ -318,13 +318,16 @@ bool profile_batched_gemm_reduce_impl(int do_verification, ...@@ -318,13 +318,16 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
reduce0_device_buf.FromDevice(d0_g_m_device_result.mData.data()); reduce0_device_buf.FromDevice(d0_g_m_device_result.mData.data());
reduce1_device_buf.FromDevice(d1_g_m_device_result.mData.data()); reduce1_device_buf.FromDevice(d1_g_m_device_result.mData.data());
float c_error = check_error(c_g_m_n_host_result, c_g_m_n_device_result); bool c_error =
float d0_error = check_error(d0_g_m_host_result, d0_g_m_device_result); ck::utils::check_err(c_g_m_n_host_result.mData, c_g_m_n_device_result.mData);
float d1_error = check_error(d1_g_m_host_result, d1_g_m_device_result); bool d0_error =
ck::utils::check_err(d0_g_m_host_result.mData, d0_g_m_device_result.mData);
pass = pass && (c_error < 1E-6); bool d1_error =
pass = pass && (d0_error < 1E-6); ck::utils::check_err(d1_g_m_host_result.mData, d1_g_m_device_result.mData);
pass = pass && (d1_error < 1E-6);
pass = pass && (c_error == true);
pass = pass && (d0_error == true);
pass = pass && (d1_error == true);
if(do_log) if(do_log)
{ {
......
...@@ -159,7 +159,7 @@ bool profile_conv_bwd_weight_impl(int do_verification, ...@@ -159,7 +159,7 @@ bool profile_conv_bwd_weight_impl(int do_verification,
float best_gb_per_sec = 0; float best_gb_per_sec = 0;
// profile device Conv instances // profile device Conv instances
bool pass = true; bool all_pass = true;
for(auto& op_ptr : op_ptrs) for(auto& op_ptr : op_ptrs)
{ {
...@@ -215,8 +215,15 @@ bool profile_conv_bwd_weight_impl(int do_verification, ...@@ -215,8 +215,15 @@ bool profile_conv_bwd_weight_impl(int do_verification,
{ {
wei_device_buf.FromDevice(weight_device_result.mData.data()); wei_device_buf.FromDevice(weight_device_result.mData.data());
pass = pass & bool pass = ck::utils::check_err(wei_k_c_y_x_host_result.mData,
ck::utils::check_err(weight_device_result.mData, weight_host_result.mData); wei_k_c_y_x_device_result.mData);
if(!pass)
{
std::cout << "Fail info:" << conv_ptr->GetTypeString() << std::endl;
}
all_pass &= pass;
if(do_log) if(do_log)
{ {
...@@ -248,7 +255,7 @@ bool profile_conv_bwd_weight_impl(int do_verification, ...@@ -248,7 +255,7 @@ bool profile_conv_bwd_weight_impl(int do_verification,
<< "\nname: " << best_op_name << "\navg_time: " << best_avg_time << "\nname: " << best_op_name << "\navg_time: " << best_avg_time
<< "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << std::endl; << "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << std::endl;
return pass; return all_pass;
} }
} // namespace profiler } // namespace profiler
......
This diff is collapsed.
This diff is collapsed.
File mode changed from 100644 to 100755
...@@ -85,7 +85,6 @@ def parse_logfile(logfile): ...@@ -85,7 +85,6 @@ def parse_logfile(logfile):
for line in open(logfile): for line in open(logfile):
if 'Best Perf' in line: if 'Best Perf' in line:
lst=line.split() lst=line.split()
print("len(lst)=",len(lst),"lst:",lst)
if len(lst)>=37: #the line is complete if len(lst)>=37: #the line is complete
tests.append(glue.join(lst[5:30])) tests.append(glue.join(lst[5:30]))
kernels.append(glue.join(lst[37:])) kernels.append(glue.join(lst[37:]))
...@@ -293,4 +292,4 @@ def main(): ...@@ -293,4 +292,4 @@ def main():
return regression return regression
if __name__ == '__main__': if __name__ == '__main__':
main() main()
\ No newline at end of file
#!/bin/bash
#
# in order to run this script you'd need the following python packages:
pip3 install --upgrade pip
pip3 install sqlalchemy pymysql pandas sshtunnel
# you would also need to set up some environment variables in order to
# post your new test results to the database and compare them to the baseline
# please contact Illia.Silin@amd.com for more details
#process results
gpu_arch=$1
python3 process_perf_data.py perf_gemm_"$gpu_arch".log
python3 process_perf_data.py perf_resnet50_N265_"$gpu_arch".log
python3 process_perf_data.py perf_resnet50_N4_"$gpu_arch".log
\ No newline at end of file
#!/bin/bash
#
# in order to run this script you'd need the following python packages:
pip3 install --upgrade pip
pip3 install sqlalchemy pymysql pandas sshtunnel
# you would also need to set up some environment variables in order to
# post your new test results to the database and compare them to the baseline
# please contact Illia.Silin@amd.com for more details
#process results
gpu_arch=$1
python3 process_perf_data.py perf_gemm_"$gpu_arch".log
python3 process_perf_data.py perf_resnet50_N265_"$gpu_arch".log
python3 process_perf_data.py perf_resnet50_N4_"$gpu_arch".log
python3 process_perf_data.py perf_batched_gemm_"$gpu_arch".log
python3 process_perf_data.py perf_grouped_gemm_"$gpu_arch".log
python3 process_perf_data.py perf_fwd_conv_"$gpu_arch".log
python3 process_perf_data.py perf_bwd_conv_"$gpu_arch".log
python3 process_perf_data.py perf_fusion_"$gpu_arch".log
python3 process_perf_data.py perf_reduction_"$gpu_arch".log
\ No newline at end of file
...@@ -11,26 +11,34 @@ INIT=$5 ...@@ -11,26 +11,34 @@ INIT=$5
LOG=$6 LOG=$6
REPEAT=$7 REPEAT=$7
######## op datatype layout verify init log repeat M___ N___ K___ StrideA StrideB StrideC BatchCount OP=$1
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 960 1024 1024 -1 -1 -1 8 DATATYPE=$2
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1920 2048 2048 -1 -1 -1 8 LAYOUT=$3
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 3840 4096 4096 -1 -1 -1 4 VERIFY=$4
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 7680 8192 8192 -1 -1 -1 2 INIT=$5
LOG=$6
REPEAT=$7
######## op datatype layout verify init log repeat M___ N___ K___ StrideA StrideB StrideC BatchStrideA BatchStrideB BatchStrideC BatchCount
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 960 1024 1024 -1 -1 -1 -1 -1 -1 8
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1920 2048 2048 -1 -1 -1 -1 -1 -1 8
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 3840 4096 4096 -1 -1 -1 -1 -1 -1 4
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 7680 8192 8192 -1 -1 -1 -1 -1 -1 2
####### op datatype layout verify init log repeat M___ N___ K___ StrideA StrideB StrideC BatchCount ####### op datatype layout verify init log repeat M___ N___ K___ StrideA StrideB StrideC BatchStrideA BatchStrideB BatchStrideC BatchCount
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1024 1024 1024 8 $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1024 1024 1024 -1 -1 -1 8
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2048 2048 2048 8 $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2048 2048 2048 -1 -1 -1 8
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4096 4096 4096 4 $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4096 4096 4096 -1 -1 -1 4
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 8192 8192 8192 8192 8192 8192 2 $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 8192 8192 8192 8192 8192 8192 -1 -1 -1 2
####### op datatype layout verify init log repeat M___ N___ K___ StrideA StrideB StrideC BatchCount ####### op datatype layout verify init log repeat M___ N___ K___ StrideA StrideB StrideC BatchStrideA BatchStrideB BatchStrideC BatchCount
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1056 1056 1056 8 $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1056 1056 1056 -1 -1 -1 8
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2080 2080 2080 8 $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2080 2080 2080 -1 -1 -1 8
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4128 4128 4128 4 $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4128 4128 4128 -1 -1 -1 4
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 8192 8192 8192 8224 8224 8224 2 $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 8192 8192 8192 8224 8224 8224 -1 -1 -1 2
####### op datatype layout verify init log repeat M___ N___ K___ StrideA StrideB StrideC BatchCount ####### op datatype layout verify init log repeat M___ N___ K___ StrideA StrideB StrideC BatchStrideA BatchStrideB BatchStrideC BatchCount
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1088 1088 1088 8 $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1088 1088 1088 -1 -1 -1 8
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2112 2112 2112 8 $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2112 2112 2112 -1 -1 -1 8
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4160 4160 4160 4 $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4160 4160 4160 -1 -1 -1 4
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 8192 8192 8192 8256 8256 8256 2 $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 8192 8192 8192 8256 8256 8256 -1 -1 -1 2
\ No newline at end of file \ No newline at end of file
#!/bin/bash
## GPU visibility
export HIP_VISIBLE_DEVICES=0
DRIVER="../build/bin/ckProfiler"
OP=$1
DATATYPE=$2
LAYOUT=$3
VERIFY=$4
INIT=$5
LOG=$6
TIME=$7
######## op datatype layout verify init log time M___ N___ K___ StrideA StrideB StrideD StrideE Alpha Beta
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 960 1024 1024 -1 -1 -1 -1 1 1
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 1920 2048 2048 -1 -1 -1 -1 1 1
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 3840 4096 4096 -1 -1 -1 -1 1 1
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 7680 8192 8192 -1 -1 -1 -1 1 1
######## op datatype layout verify init log time M___ N___ K___ StrideA StrideB StrideD StrideE Alpha Beta
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 960 1024 1024 -1 -1 0 -1 1 1
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 1920 2048 2048 -1 -1 0 -1 1 1
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 3840 4096 4096 -1 -1 0 -1 1 1
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 7680 8192 8192 -1 -1 0 -1 1 1
######## op datatype layout verify init log time M___ N___ K___ StrideA StrideB StrideD StrideE Alpha Beta
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 1000 1000 1000 -1 -1 0 -1 1 1
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 2000 2000 2000 -1 -1 0 -1 1 1
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 4000 4000 4000 -1 -1 0 -1 1 1
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 8000 8000 8000 -1 -1 0 -1 1 1
######## op datatype layout verify init log time M___ N___ K___ StrideA StrideB StrideD StrideE Alpha Beta
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 1024 1024 1024 1056 1056 1056 1056 1 1
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 2048 2048 2048 2080 2080 2080 2080 1 1
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 4096 4096 4096 4128 4128 4128 4128 1 1
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 8192 8192 8192 8224 8224 8224 8224 1 1
######## op datatype layout verify init log time M___ N___ K___ StrideA StrideB StrideD StrideE Alpha Beta
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 1024 1024 1024 1088 1088 1088 1088 1 1
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 2048 2048 2048 2112 2112 2112 2112 1 1
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 4096 4096 4096 4160 4160 4160 4160 1 1
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 8192 8192 8192 8256 8256 8256 8256 1 1
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment