Commit 4fe6f9ae authored by ltqin's avatar ltqin
Browse files

change arch

parent 5c7dfc7a
......@@ -38,29 +38,14 @@ enum struct ArchitectureEnum
Gfx940,
Gfx1030
};
enum struct ArchFeatureEnum
{
None,
Xdl,
Dl,
Wmma
};
template <ArchitectureEnum... Is>
struct ArchitectureEnumSequence
{
static constexpr int mSize = sizeof...(Is);
__host__ __device__ static constexpr ArchitectureEnum At(int I)
{
// The trailing dummy element keeps the compiler from complaining about an empty array when mSize == 0.
const ArchitectureEnum mData[mSize + 1] = {Is..., ArchitectureEnum::None};
return mData[I];
}
/// Primary template: the fallback used for any (Arch, DeviceOp) pair that has
/// no dedicated specialization. It reports no instances at all.
/// NOTE(review): specializations supplying real instances are presumably
/// defined elsewhere -- confirm.
template <ArchitectureEnum Arch, typename DeviceOp>
struct DeviceOperationInstances
{
    /// @return an empty std::tuple<>, i.e. nothing to register.
    static auto get_device_instances()
    {
        return std::make_tuple();
    }
};
template <ArchFeatureEnum Feature, typename DeviceOp>
struct DeviceOperationInstances;
template <typename Arch, typename DeviceOp>
template <ArchitectureEnum Arch, typename DeviceOp>
struct DeviceOperationInstanceCreator;
} // namespace instance
} // namespace device
......
......@@ -18,7 +18,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
template <typename Arch,
template <ArchitectureEnum Arch,
index_t NumDimG,
index_t NumDimM,
index_t NumDimN,
......@@ -74,11 +74,8 @@ struct DeviceOperationInstanceCreator<Arch,
MaskingSpec>;
static void add_device_instances(std::vector<std::unique_ptr<DeviceOp>>& instances)
{
if constexpr(DeviceOperationInstances<ArchFeatureEnum::Xdl,
DeviceOp>::template is_surport<Arch>())
add_device_operation_instances(
instances,
DeviceOperationInstances<ArchFeatureEnum::Xdl, DeviceOp>::get_device_instances());
add_device_operation_instances(
instances, DeviceOperationInstances<Arch, DeviceOp>::get_device_instances());
}
};
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/batched_gemm_softmax_gemm_permute.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
/// Factory specialization for DeviceBatchedGemmSoftmaxGemmPermute.
///
/// GetInstances() picks the DeviceOperationInstanceCreator that matches the
/// name returned by ck::get_device_name() and lets it fill the result vector.
/// Only "gfx908", "gfx90a" and "gfx940" are dispatched here.
template <index_t NumDimG,
          index_t NumDimM,
          index_t NumDimN,
          index_t NumDimK,
          index_t NumDimO,
          typename ADataType,
          typename B0DataType,
          typename B1DataType,
          typename CDataType,
          typename Acc0BiasDataType,
          typename Acc1BiasDataType,
          typename AElementwiseOperation,
          typename B0ElementwiseOperation,
          typename C0DEElementwiseOperation,
          typename B1ElementwiseOperation,
          typename C1DEElementwiseOperation,
          MaskingSpecialization MaskingSpec>
struct DeviceOperationInstanceFactory<
    ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute<NumDimG,
                                                                      NumDimM,
                                                                      NumDimN,
                                                                      NumDimK,
                                                                      NumDimO,
                                                                      ADataType,
                                                                      B0DataType,
                                                                      B1DataType,
                                                                      CDataType,
                                                                      Acc0BiasDataType,
                                                                      Acc1BiasDataType,
                                                                      AElementwiseOperation,
                                                                      B0ElementwiseOperation,
                                                                      C0DEElementwiseOperation,
                                                                      B1ElementwiseOperation,
                                                                      C1DEElementwiseOperation,
                                                                      MaskingSpec>>
{
    /// The device operation interface whose instances this factory collects.
    using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute<NumDimG,
                                                         NumDimM,
                                                         NumDimN,
                                                         NumDimK,
                                                         NumDimO,
                                                         ADataType,
                                                         B0DataType,
                                                         B1DataType,
                                                         CDataType,
                                                         Acc0BiasDataType,
                                                         Acc1BiasDataType,
                                                         AElementwiseOperation,
                                                         B0ElementwiseOperation,
                                                         C0DEElementwiseOperation,
                                                         B1ElementwiseOperation,
                                                         C1DEElementwiseOperation,
                                                         MaskingSpec>;

    /// Collects the DeviceOp instances registered for the current device.
    ///
    /// @return a vector of owning pointers filled by the matching
    ///         DeviceOperationInstanceCreator; the vector is empty when the
    ///         device name is none of "gfx908", "gfx90a" or "gfx940".
    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

        // Look the device name up once and reuse it for every comparison
        // instead of calling ck::get_device_name() again in each branch.
        // NOTE(review): presumably each call queries the GPU runtime -- confirm.
        const auto& device_name = ck::get_device_name();

        if(device_name == "gfx908")
        {
            DeviceOperationInstanceCreator<ArchitectureEnum::Gfx908,
                                           DeviceOp>::add_device_instances(op_ptrs);
        }
        else if(device_name == "gfx90a")
        {
            DeviceOperationInstanceCreator<ArchitectureEnum::Gfx90a,
                                           DeviceOp>::add_device_instances(op_ptrs);
        }
        else if(device_name == "gfx940")
        {
            DeviceOperationInstanceCreator<ArchitectureEnum::Gfx940,
                                           DeviceOp>::add_device_instances(op_ptrs);
        }
        // Any other device name falls through and yields an empty vector.

        return op_ptrs;
    }
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -66,8 +66,8 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>;
DeviceOperationInstanceCreator<ArchitectureEnumSequence<ArchitectureEnum::Gfx908>,
DeviceOp>::add_device_instances(instances);
DeviceOperationInstanceCreator<ArchitectureEnum::Gfx908, DeviceOp>::add_device_instances(
instances);
}
void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
......@@ -108,8 +108,8 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
PassThrough,
PassThrough,
MaskingSpecialization::MaskDisabled>;
DeviceOperationInstanceCreator<ArchitectureEnumSequence<ArchitectureEnum::Gfx908>,
DeviceOp>::add_device_instances(instances);
DeviceOperationInstanceCreator<ArchitectureEnum::Gfx908, DeviceOp>::add_device_instances(
instances);
}
} // namespace instance
......
......@@ -65,8 +65,8 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
PassThrough,
PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>;
DeviceOperationInstanceCreator<ArchitectureEnumSequence<ArchitectureEnum::Gfx908>,
DeviceOp>::add_device_instances(instances);
DeviceOperationInstanceCreator<ArchitectureEnum::Gfx908, DeviceOp>::add_device_instances(
instances);
}
void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
......@@ -107,8 +107,8 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
PassThrough,
PassThrough,
MaskingSpecialization::MaskDisabled>;
DeviceOperationInstanceCreator<ArchitectureEnumSequence<ArchitectureEnum::Gfx908>,
DeviceOp>::add_device_instances(instances);
DeviceOperationInstanceCreator<ArchitectureEnum::Gfx908, DeviceOp>::add_device_instances(
instances);
}
} // namespace instance
......
......@@ -65,8 +65,8 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
PassThrough,
PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>;
DeviceOperationInstanceCreator<ArchitectureEnumSequence<ArchitectureEnum::Gfx908>,
DeviceOp>::add_device_instances(instances);
DeviceOperationInstanceCreator<ArchitectureEnum::Gfx908, DeviceOp>::add_device_instances(
instances);
}
void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
......@@ -107,8 +107,8 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
PassThrough,
PassThrough,
MaskingSpecialization::MaskDisabled>;
DeviceOperationInstanceCreator<ArchitectureEnumSequence<ArchitectureEnum::Gfx908>,
DeviceOp>::add_device_instances(instances);
DeviceOperationInstanceCreator<ArchitectureEnum::Gfx908, DeviceOp>::add_device_instances(
instances);
}
} // namespace instance
......
......@@ -65,8 +65,8 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
PassThrough,
PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>;
DeviceOperationInstanceCreator<ArchitectureEnumSequence<ArchitectureEnum::Gfx908>,
DeviceOp>::add_device_instances(instances);
DeviceOperationInstanceCreator<ArchitectureEnum::Gfx908, DeviceOp>::add_device_instances(
instances);
}
void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
std::vector<
......@@ -106,8 +106,8 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
PassThrough,
PassThrough,
MaskingSpecialization::MaskDisabled>;
DeviceOperationInstanceCreator<ArchitectureEnumSequence<ArchitectureEnum::Gfx908>,
DeviceOp>::add_device_instances(instances);
DeviceOperationInstanceCreator<ArchitectureEnum::Gfx908, DeviceOp>::add_device_instances(
instances);
}
} // namespace instance
......
......@@ -5,7 +5,7 @@
#include <vector>
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/batched_gemm_softmax_gemm_permute.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute_general.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
......@@ -134,11 +134,8 @@ int main()
MaskingSpec>;
// get device op instances
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
ck::tensor_operation::device::instance::DeviceOperationInstanceCreator<
ck::tensor_operation::device::instance::ArchitectureEnumSequence<
ck::tensor_operation::device::instance::ArchitectureEnum::All>,
DeviceOp>::add_device_instances(op_ptrs);
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment