Commit 7b73260c authored by ltqin's avatar ltqin
Browse files

add architecture

parent 7f632d63
add_example_executable(example_gemm_bias_softmax_gemm_permute gemm_bias_softmax_gemm_permute.cpp) add_example_executable(example_gemm_bias_softmax_gemm_permute gemm_bias_softmax_gemm_permute.cpp)
add_example_executable(example_gemm_bias_softmax_gemm_permute_nolib gemm_bias_softmax_gemm_permute_nolib.cpp)
...@@ -29,7 +29,12 @@ void add_device_operation_instances(std::vector<std::unique_ptr<BaseOp>>& op_ins ...@@ -29,7 +29,12 @@ void add_device_operation_instances(std::vector<std::unique_ptr<BaseOp>>& op_ins
}); });
} }
template <typename DeviceOp, typename Tag = void> enum struct ArchitectureEnum
{
Xdl,
Dl
};
template <typename DeviceOp, ArchitectureEnum Arch = ArchitectureEnum::Xdl>
struct DeviceOperationInstanceCreator; struct DeviceOperationInstanceCreator;
} // namespace instance } // namespace instance
} // namespace device } // namespace device
......
...@@ -34,7 +34,8 @@ template <index_t NumDimG, ...@@ -34,7 +34,8 @@ template <index_t NumDimG,
typename C0DEElementwiseOperation, typename C0DEElementwiseOperation,
typename B1ElementwiseOperation, typename B1ElementwiseOperation,
typename C1DEElementwiseOperation, typename C1DEElementwiseOperation,
MaskingSpecialization MaskingSpec> MaskingSpecialization MaskingSpec,
ArchitectureEnum Arch = ArchitectureEnum::Xdl>
void add_device_instances( void add_device_instances(
std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<NumDimG, std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<NumDimG,
NumDimM, NumDimM,
...@@ -72,7 +73,7 @@ void add_device_instances( ...@@ -72,7 +73,7 @@ void add_device_instances(
C1DEElementwiseOperation, C1DEElementwiseOperation,
MaskingSpec>; MaskingSpec>;
add_device_operation_instances( add_device_operation_instances(
instances, DeviceOperationInstanceCreator<DeviceOp>::create_device_instances()); instances, DeviceOperationInstanceCreator<DeviceOp, Arch>::create_device_instances());
} }
} // namespace instance } // namespace instance
} // namespace device } // namespace device
......
...@@ -92,7 +92,8 @@ struct DeviceOperationInstanceCreator<DeviceBatchedGemmSoftmaxGemmPermute<NumDim ...@@ -92,7 +92,8 @@ struct DeviceOperationInstanceCreator<DeviceBatchedGemmSoftmaxGemmPermute<NumDim
C0DEElementwiseOperation, C0DEElementwiseOperation,
B1ElementwiseOperation, B1ElementwiseOperation,
C1DEElementwiseOperation, C1DEElementwiseOperation,
MaskingSpec>> MaskingSpec>,
ArchitectureEnum::Xdl>
{ {
static auto create_device_instances() static auto create_device_instances()
{ {
......
...@@ -92,7 +92,8 @@ struct DeviceOperationInstanceCreator<DeviceBatchedGemmSoftmaxGemmPermute<NumDim ...@@ -92,7 +92,8 @@ struct DeviceOperationInstanceCreator<DeviceBatchedGemmSoftmaxGemmPermute<NumDim
C0DEElementwiseOperation, C0DEElementwiseOperation,
B1ElementwiseOperation, B1ElementwiseOperation,
C1DEElementwiseOperation, C1DEElementwiseOperation,
MaskingSpec>> MaskingSpec>,
ArchitectureEnum::Xdl>
{ {
static auto create_device_instances() static auto create_device_instances()
{ {
......
...@@ -135,7 +135,25 @@ int main() ...@@ -135,7 +135,25 @@ int main()
// get device op instances // get device op instances
std::vector<std::unique_ptr<DeviceOp>> op_ptrs; std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
ck::tensor_operation::device::instance::add_device_instances(op_ptrs); ck::tensor_operation::device::instance::add_device_instances<
2,
1,
1,
1,
1,
ADataType,
B0DataType,
B1DataType,
CDataType,
ck::Tuple<D00DataType, D01DataType>,
ck::Tuple<>,
AElementOp,
B0ElementOp,
Acc0ElementOp,
B1ElementOp,
CElementOp,
MaskingSpec,
ck::tensor_operation::device::instance::ArchitectureEnum::Xdl>(op_ptrs);
std::cout << "found " << op_ptrs.size() << " instances" << std::endl; std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
......
cmake_minimum_required(VERSION 3.15)
project(ck_app)
add_compile_options(-std=c++17)
find_package(composable_kernel 1.0.0 COMPONENTS device_operations)
find_package(hip REQUIRED PATHS /opt/rocm)
message(STATUS "Build with HIP ${hip_VERSION}")
# add all example subdir
file(GLOB dir_list LIST_DIRECTORIES true *)
FOREACH(subdir ${dir_list})
IF(IS_DIRECTORY "${subdir}" AND (NOT "${subdir}" MATCHES "build"))
add_subdirectory(${subdir})
ENDIF()
ENDFOREACH()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment