Commit 0b11569f authored by Chao Liu's avatar Chao Liu
Browse files

Merge remote-tracking branch 'origin/develop' into batched_gemm_c_permute

parents e8d3a0fb fa9a0a5c
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
......@@ -23,12 +26,11 @@ struct ReferenceSoftmax : public device::BaseOperator
Tensor<OutDataType>& out,
AccDataType alpha,
AccDataType beta,
const index_t rank,
const std::vector<index_t> sm_reduce_dims)
: in_(in), out_(out), alpha_(alpha), beta_(beta), sm_reduce_dims_(sm_reduce_dims)
{
// std::cout << "debug: scalar dims: ";
for(int i = 0; i < rank; i++)
for(size_t i = 0; i < in.mDesc.GetNumOfDimension(); i++)
{
if(std::find(sm_reduce_dims.begin(), sm_reduce_dims.end(), i) ==
sm_reduce_dims.end())
......@@ -44,7 +46,6 @@ struct ReferenceSoftmax : public device::BaseOperator
Tensor<OutDataType>& out_;
AccDataType alpha_;
AccDataType beta_;
index_t rank_;
std::vector<index_t> sm_reduce_dims_;
std::vector<index_t> sm_scalar_dims_; // dim after internal max/sum reduction
};
......@@ -133,10 +134,9 @@ struct ReferenceSoftmax : public device::BaseOperator
Tensor<OutDataType>& out,
AccDataType alpha,
AccDataType beta,
const index_t rank,
const std::vector<index_t> sm_reduce_dims)
{
return Argument{in, out, alpha, beta, rank, sm_reduce_dims};
return Argument{in, out, alpha, beta, sm_reduce_dims};
}
static auto MakeInvoker() { return Invoker{}; }
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef NAIVE_CONV_FWD_HPP
#define NAIVE_CONV_FWD_HPP
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <vector>
#include "ck/utility/functional2.hpp"
namespace ck {
namespace tensor_operation {
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_batched_gemm_instance {
using DeviceBatchedGemmNoOpPtr = ck::tensor_operation::device::DeviceBatchedGemmPtr<
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
void add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances(
std::vector<DeviceBatchedGemmNoOpPtr>&);
void add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances(
std::vector<DeviceBatchedGemmNoOpPtr>&);
void add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances(
std::vector<DeviceBatchedGemmNoOpPtr>&);
void add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances(
std::vector<DeviceBatchedGemmNoOpPtr>&);
void add_device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances(
std::vector<DeviceBatchedGemmNoOpPtr>&);
void add_device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances(
std::vector<DeviceBatchedGemmNoOpPtr>&);
void add_device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances(
std::vector<DeviceBatchedGemmNoOpPtr>&);
void add_device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances(
std::vector<DeviceBatchedGemmNoOpPtr>&);
void add_device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instances(
std::vector<DeviceBatchedGemmNoOpPtr>&);
void add_device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances(
std::vector<DeviceBatchedGemmNoOpPtr>&);
void add_device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances(
std::vector<DeviceBatchedGemmNoOpPtr>&);
void add_device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances(
std::vector<DeviceBatchedGemmNoOpPtr>&);
void add_device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances(
std::vector<DeviceBatchedGemmNoOpPtr>&);
void add_device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances(
std::vector<DeviceBatchedGemmNoOpPtr>&);
void add_device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances(
std::vector<DeviceBatchedGemmNoOpPtr>&);
void add_device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances(
std::vector<DeviceBatchedGemmNoOpPtr>&);
template <typename ADataType,
typename BDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout>
auto get_device_batched_gemm_instances()
{
std::vector<DeviceBatchedGemmNoOpPtr> op_ptrs;
if constexpr(is_same<ADataType, float>::value && is_same<BDataType, float>::value &&
is_same<CDataType, float>::value)
{
if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_batched_gemm_instance::
add_device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instances(op_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_batched_gemm_instance::
add_device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances(op_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_batched_gemm_instance::
add_device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances(op_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_batched_gemm_instance::
add_device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances(op_ptrs);
}
}
else if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
is_same<CDataType, half_t>::value)
{
if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_batched_gemm_instance::
add_device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances(op_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_batched_gemm_instance::
add_device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances(op_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_batched_gemm_instance::
add_device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances(op_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_batched_gemm_instance::
add_device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances(op_ptrs);
}
}
else if constexpr(is_same<ADataType, bhalf_t>::value && is_same<BDataType, bhalf_t>::value &&
is_same<CDataType, bhalf_t>::value)
{
if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_batched_gemm_instance::
add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances(op_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_batched_gemm_instance::
add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances(op_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_batched_gemm_instance::
add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances(op_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_batched_gemm_instance::
add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances(op_ptrs);
}
}
else if constexpr(is_same<ADataType, int8_t>::value && is_same<BDataType, int8_t>::value &&
is_same<CDataType, int8_t>::value)
{
if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_batched_gemm_instance::
add_device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances(op_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_batched_gemm_instance::
add_device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances(op_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_batched_gemm_instance::
add_device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances(op_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_batched_gemm_instance::
add_device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances(op_ptrs);
}
}
return op_ptrs;
}
} // namespace device_batched_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
using Normalize = ck::tensor_operation::element_wise::Normalize;
using DeviceNormalizeFromMeanMeanSquarePtr =
ck::tensor_operation::device::DeviceElementwisePtr<5, 1, 2, Normalize>;
void add_device_normalize_from_mean_squaremean_f16_f32_f32_f16_f16_instances(
std::vector<DeviceNormalizeFromMeanMeanSquarePtr>& instances);
template <typename InputType,
typename MeanType,
typename MeanSquareType,
typename GammaDataType,
typename BetaDataType,
typename OutputType>
auto get_device_normalize_from_mean_meansquare_instances()
{
std::vector<DeviceNormalizeFromMeanMeanSquarePtr> op_ptrs;
if constexpr(is_same<InputType, half_t>::value && is_same<MeanType, float>::value &&
is_same<MeanSquareType, float>::value && is_same<GammaDataType, half_t>::value &&
is_same<BetaDataType, half_t>::value && is_same<OutputType, half_t>::value)
{
ck::tensor_operation::device::
add_device_normalize_from_mean_squaremean_f16_f32_f32_f16_f16_instances(op_ptrs);
}
return op_ptrs;
}
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
using DeviceGemmAddAddFastGeluPtr = ck::tensor_operation::device::DeviceGemmMultipleDPtr<
2,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::AddAddFastGelu>;
void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(
std::vector<DeviceGemmAddAddFastGeluPtr>&);
void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(
std::vector<DeviceGemmAddAddFastGeluPtr>&);
void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(
std::vector<DeviceGemmAddAddFastGeluPtr>&);
void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(
std::vector<DeviceGemmAddAddFastGeluPtr>&);
template <typename ADataType,
typename BDataType,
typename AccDataType,
typename D0DataType,
typename D1DataType,
typename EDataType,
typename ALayout,
typename BLayout,
typename D0Layout,
typename D1Layout,
typename ELayout>
auto get_device_gemm_add_add_fastgelu_instances()
{
std::vector<DeviceGemmAddAddFastGeluPtr> op_ptrs;
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<EDataType, half_t>)
{
if constexpr(is_same_v<ALayout, tensor_layout::gemm::RowMajor> &&
is_same_v<BLayout, tensor_layout::gemm::RowMajor> &&
is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, tensor_layout::gemm::RowMajor> &&
is_same_v<BLayout, tensor_layout::gemm::ColumnMajor> &&
is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, tensor_layout::gemm::ColumnMajor> &&
is_same_v<BLayout, tensor_layout::gemm::RowMajor> &&
is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, tensor_layout::gemm::ColumnMajor> &&
is_same_v<BLayout, tensor_layout::gemm::ColumnMajor> &&
is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(
op_ptrs);
}
}
return op_ptrs;
}
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
using DeviceGemmNoOpPtr =
ck::tensor_operation::device::DeviceGemmPtr<ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances(
std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances(
std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances(
std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances(
std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(
std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
template <typename ADataType,
typename BDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout>
auto get_device_gemm_instances()
{
std::vector<DeviceGemmNoOpPtr> op_ptrs;
if constexpr(is_same<ADataType, float>::value && is_same<BDataType, float>::value &&
is_same<CDataType, float>::value)
{
if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(op_ptrs);
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances(op_ptrs);
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(op_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(op_ptrs);
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(op_ptrs);
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(op_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(op_ptrs);
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(op_ptrs);
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(op_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(op_ptrs);
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances(op_ptrs);
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(op_ptrs);
}
}
else if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
is_same<CDataType, half_t>::value)
{
if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(op_ptrs);
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(op_ptrs);
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(op_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(op_ptrs);
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(op_ptrs);
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(op_ptrs);
}
}
else if constexpr(is_same<ADataType, ck::bhalf_t>::value &&
is_same<BDataType, ck::bhalf_t>::value &&
is_same<CDataType, ck::bhalf_t>::value)
{
if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances(op_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances(op_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances(op_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances(op_ptrs);
}
}
else if constexpr(is_same<ADataType, int8_t>::value && is_same<BDataType, int8_t>::value &&
is_same<CDataType, int8_t>::value)
{
if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(op_ptrs);
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(op_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(op_ptrs);
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(op_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(op_ptrs);
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(op_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(op_ptrs);
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(op_ptrs);
}
}
return op_ptrs;
}
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
using DeviceGemmAddAddMeanSquareMeanPtr = ck::tensor_operation::device::DeviceGemmReducePtr<1, 2>;
void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances(
std::vector<DeviceGemmAddAddMeanSquareMeanPtr>&);
void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances(
std::vector<DeviceGemmAddAddMeanSquareMeanPtr>&);
void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances(
std::vector<DeviceGemmAddAddMeanSquareMeanPtr>&);
void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances(
std::vector<DeviceGemmAddAddMeanSquareMeanPtr>&);
template <typename ADataType,
typename BDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout>
auto get_device_gemm_add_add_mean_squaremean_instances()
{
std::vector<DeviceGemmAddAddMeanSquareMeanPtr> op_ptrs;
if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
is_same<CDataType, half_t>::value)
{
if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances(
op_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances(
op_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances(
op_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances(
op_ptrs);
}
}
return op_ptrs;
}
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_splitk.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
using DeviceGemmSplitKNoOpPtr = ck::tensor_operation::device::DeviceGemmSplitKPtr<
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
void add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(
std::vector<DeviceGemmSplitKNoOpPtr>&);
void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(
std::vector<DeviceGemmSplitKNoOpPtr>&);
void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(
std::vector<DeviceGemmSplitKNoOpPtr>&);
void add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(
std::vector<DeviceGemmSplitKNoOpPtr>&);
void add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(
std::vector<DeviceGemmSplitKNoOpPtr>&);
void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(
std::vector<DeviceGemmSplitKNoOpPtr>&);
void add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(
std::vector<DeviceGemmSplitKNoOpPtr>&);
void add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(
std::vector<DeviceGemmSplitKNoOpPtr>&);
template <typename ADataType,
typename BDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout>
auto get_device_gemm_splitk_instances()
{
std::vector<DeviceGemmSplitKNoOpPtr> op_ptrs;
if constexpr(is_same<ADataType, float>::value && is_same<BDataType, float>::value &&
is_same<CDataType, float>::value)
{
if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(op_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(op_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(op_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(op_ptrs);
}
}
else if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
is_same<CDataType, half_t>::value)
{
if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(op_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(op_ptrs);
}
}
return op_ptrs;
}
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment