"megatron/git@developer.sourcefind.cn:wuxk1/megatron-lm.git" did not exist on "27ecc17a1ea93b5b6b68145df06094de0aa53356"
Commit 4fe6f9ae authored by ltqin
Browse files

change arch

parent 5c7dfc7a
...@@ -38,29 +38,14 @@ enum struct ArchitectureEnum ...@@ -38,29 +38,14 @@ enum struct ArchitectureEnum
Gfx940, Gfx940,
Gfx1030 Gfx1030
}; };
// Feature tags used as the first (non-type) template argument of
// DeviceOperationInstances<Feature, DeviceOp>; the instance creator in this
// change set selects ArchFeatureEnum::Xdl.
// NOTE(review): the precise hardware meaning of each tag is not visible here —
// confirm against the architecture documentation before relying on it.
enum struct ArchFeatureEnum
{
None,
Xdl,
Dl,
Wmma
};
template <ArchitectureEnum... Is>
struct ArchitectureEnumSequence
{
static constexpr int mSize = sizeof...(Is);
__host__ __device__ static constexpr ArchitectureEnum At(int I) template <ArchitectureEnum Arch, typename DeviceOp>
{ struct DeviceOperationInstances
// the last dummy element is to prevent compiler complain about empty array, when mSize = 0 {
const ArchitectureEnum mData[mSize + 1] = {Is..., ArchitectureEnum::None}; static auto get_device_instances() { return std::tuple<>{}; }
return mData[I];
}
}; };
template <ArchFeatureEnum Feature, typename DeviceOp>
struct DeviceOperationInstances;
template <typename Arch, typename DeviceOp> template <ArchitectureEnum Arch, typename DeviceOp>
struct DeviceOperationInstanceCreator; struct DeviceOperationInstanceCreator;
} // namespace instance } // namespace instance
} // namespace device } // namespace device
......
...@@ -18,7 +18,7 @@ namespace tensor_operation { ...@@ -18,7 +18,7 @@ namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
template <typename Arch, template <ArchitectureEnum Arch,
index_t NumDimG, index_t NumDimG,
index_t NumDimM, index_t NumDimM,
index_t NumDimN, index_t NumDimN,
...@@ -74,11 +74,8 @@ struct DeviceOperationInstanceCreator<Arch, ...@@ -74,11 +74,8 @@ struct DeviceOperationInstanceCreator<Arch,
MaskingSpec>; MaskingSpec>;
static void add_device_instances(std::vector<std::unique_ptr<DeviceOp>>& instances) static void add_device_instances(std::vector<std::unique_ptr<DeviceOp>>& instances)
{ {
if constexpr(DeviceOperationInstances<ArchFeatureEnum::Xdl, add_device_operation_instances(
DeviceOp>::template is_surport<Arch>()) instances, DeviceOperationInstances<Arch, DeviceOp>::get_device_instances());
add_device_operation_instances(
instances,
DeviceOperationInstances<ArchFeatureEnum::Xdl, DeviceOp>::get_device_instances());
} }
}; };
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/batched_gemm_softmax_gemm_permute.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
/// @brief Instance factory specialization for DeviceBatchedGemmSoftmaxGemmPermute.
///
/// Selects, at run time, which architecture-specific DeviceOperationInstanceCreator
/// populates the list of device operation instances, based on the name reported by
/// ck::get_device_name().
///
/// All template parameters are forwarded unchanged to
/// DeviceBatchedGemmSoftmaxGemmPermute to form the DeviceOp type.
template <index_t NumDimG,
          index_t NumDimM,
          index_t NumDimN,
          index_t NumDimK,
          index_t NumDimO,
          typename ADataType,
          typename B0DataType,
          typename B1DataType,
          typename CDataType,
          typename Acc0BiasDataType,
          typename Acc1BiasDataType,
          typename AElementwiseOperation,
          typename B0ElementwiseOperation,
          typename C0DEElementwiseOperation,
          typename B1ElementwiseOperation,
          typename C1DEElementwiseOperation,
          MaskingSpecialization MaskingSpec>
struct DeviceOperationInstanceFactory<
    ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute<NumDimG,
                                                                      NumDimM,
                                                                      NumDimN,
                                                                      NumDimK,
                                                                      NumDimO,
                                                                      ADataType,
                                                                      B0DataType,
                                                                      B1DataType,
                                                                      CDataType,
                                                                      Acc0BiasDataType,
                                                                      Acc1BiasDataType,
                                                                      AElementwiseOperation,
                                                                      B0ElementwiseOperation,
                                                                      C0DEElementwiseOperation,
                                                                      B1ElementwiseOperation,
                                                                      C1DEElementwiseOperation,
                                                                      MaskingSpec>>
{
    /// The device operation interface whose instances this factory collects.
    using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute<NumDimG,
                                                         NumDimM,
                                                         NumDimN,
                                                         NumDimK,
                                                         NumDimO,
                                                         ADataType,
                                                         B0DataType,
                                                         B1DataType,
                                                         CDataType,
                                                         Acc0BiasDataType,
                                                         Acc1BiasDataType,
                                                         AElementwiseOperation,
                                                         B0ElementwiseOperation,
                                                         C0DEElementwiseOperation,
                                                         B1ElementwiseOperation,
                                                         C1DEElementwiseOperation,
                                                         MaskingSpec>;

    /// @brief Collects the device operation instances for the current device.
    ///
    /// @return A std::vector<std::unique_ptr<DeviceOp>> filled by the
    ///         DeviceOperationInstanceCreator matching the device name
    ///         ("gfx908", "gfx90a" or "gfx940"). For any other device name
    ///         the returned vector is empty.
    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

        // Fetch the device name once and reuse it for every comparison,
        // rather than re-invoking ck::get_device_name() in each branch.
        const auto device_name = ck::get_device_name();

        if(device_name == "gfx908")
        {
            DeviceOperationInstanceCreator<ArchitectureEnum::Gfx908,
                                           DeviceOp>::add_device_instances(op_ptrs);
        }
        else if(device_name == "gfx90a")
        {
            DeviceOperationInstanceCreator<ArchitectureEnum::Gfx90a,
                                           DeviceOp>::add_device_instances(op_ptrs);
        }
        else if(device_name == "gfx940")
        {
            DeviceOperationInstanceCreator<ArchitectureEnum::Gfx940,
                                           DeviceOp>::add_device_instances(op_ptrs);
        }
        // No branch for other device names: op_ptrs is returned empty.

        return op_ptrs;
    }
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -66,8 +66,8 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i ...@@ -66,8 +66,8 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
PassThrough, PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>; MaskingSpecialization::MaskOutUpperTriangle>;
DeviceOperationInstanceCreator<ArchitectureEnumSequence<ArchitectureEnum::Gfx908>, DeviceOperationInstanceCreator<ArchitectureEnum::Gfx908, DeviceOp>::add_device_instances(
DeviceOp>::add_device_instances(instances); instances);
} }
void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances( void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
...@@ -108,8 +108,8 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i ...@@ -108,8 +108,8 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
PassThrough, PassThrough,
PassThrough, PassThrough,
MaskingSpecialization::MaskDisabled>; MaskingSpecialization::MaskDisabled>;
DeviceOperationInstanceCreator<ArchitectureEnumSequence<ArchitectureEnum::Gfx908>, DeviceOperationInstanceCreator<ArchitectureEnum::Gfx908, DeviceOp>::add_device_instances(
DeviceOp>::add_device_instances(instances); instances);
} }
} // namespace instance } // namespace instance
......
...@@ -65,8 +65,8 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i ...@@ -65,8 +65,8 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
PassThrough, PassThrough,
PassThrough, PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>; MaskingSpecialization::MaskOutUpperTriangle>;
DeviceOperationInstanceCreator<ArchitectureEnumSequence<ArchitectureEnum::Gfx908>, DeviceOperationInstanceCreator<ArchitectureEnum::Gfx908, DeviceOp>::add_device_instances(
DeviceOp>::add_device_instances(instances); instances);
} }
void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances( void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
...@@ -107,8 +107,8 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i ...@@ -107,8 +107,8 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
PassThrough, PassThrough,
PassThrough, PassThrough,
MaskingSpecialization::MaskDisabled>; MaskingSpecialization::MaskDisabled>;
DeviceOperationInstanceCreator<ArchitectureEnumSequence<ArchitectureEnum::Gfx908>, DeviceOperationInstanceCreator<ArchitectureEnum::Gfx908, DeviceOp>::add_device_instances(
DeviceOp>::add_device_instances(instances); instances);
} }
} // namespace instance } // namespace instance
......
...@@ -65,8 +65,8 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i ...@@ -65,8 +65,8 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
PassThrough, PassThrough,
PassThrough, PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>; MaskingSpecialization::MaskOutUpperTriangle>;
DeviceOperationInstanceCreator<ArchitectureEnumSequence<ArchitectureEnum::Gfx908>, DeviceOperationInstanceCreator<ArchitectureEnum::Gfx908, DeviceOp>::add_device_instances(
DeviceOp>::add_device_instances(instances); instances);
} }
void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances( void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
...@@ -107,8 +107,8 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i ...@@ -107,8 +107,8 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
PassThrough, PassThrough,
PassThrough, PassThrough,
MaskingSpecialization::MaskDisabled>; MaskingSpecialization::MaskDisabled>;
DeviceOperationInstanceCreator<ArchitectureEnumSequence<ArchitectureEnum::Gfx908>, DeviceOperationInstanceCreator<ArchitectureEnum::Gfx908, DeviceOp>::add_device_instances(
DeviceOp>::add_device_instances(instances); instances);
} }
} // namespace instance } // namespace instance
......
...@@ -65,8 +65,8 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i ...@@ -65,8 +65,8 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
PassThrough, PassThrough,
PassThrough, PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>; MaskingSpecialization::MaskOutUpperTriangle>;
DeviceOperationInstanceCreator<ArchitectureEnumSequence<ArchitectureEnum::Gfx908>, DeviceOperationInstanceCreator<ArchitectureEnum::Gfx908, DeviceOp>::add_device_instances(
DeviceOp>::add_device_instances(instances); instances);
} }
void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances( void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
std::vector< std::vector<
...@@ -106,8 +106,8 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i ...@@ -106,8 +106,8 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
PassThrough, PassThrough,
PassThrough, PassThrough,
MaskingSpecialization::MaskDisabled>; MaskingSpecialization::MaskDisabled>;
DeviceOperationInstanceCreator<ArchitectureEnumSequence<ArchitectureEnum::Gfx908>, DeviceOperationInstanceCreator<ArchitectureEnum::Gfx908, DeviceOp>::add_device_instances(
DeviceOp>::add_device_instances(instances); instances);
} }
} // namespace instance } // namespace instance
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
#include <vector> #include <vector>
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/batched_gemm_softmax_gemm_permute.hpp" #include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute_general.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp" #include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
...@@ -134,11 +134,8 @@ int main() ...@@ -134,11 +134,8 @@ int main()
MaskingSpec>; MaskingSpec>;
// get device op instances // get device op instances
std::vector<std::unique_ptr<DeviceOp>> op_ptrs; const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
ck::tensor_operation::device::instance::DeviceOperationInstanceCreator< DeviceOp>::GetInstances();
ck::tensor_operation::device::instance::ArchitectureEnumSequence<
ck::tensor_operation::device::instance::ArchitectureEnum::All>,
DeviceOp>::add_device_instances(op_ptrs);
std::cout << "found " << op_ptrs.size() << " instances" << std::endl; std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment