Commit 14822d71 authored by Jing Zhang's avatar Jing Zhang
Browse files

merge

parents 5b02dfaf 80560ef2
......@@ -3,7 +3,7 @@ repos:
hooks:
- id: clang-format
name: clang-format
entry: clang-format-10 -i --style=file
entry: clang-format-12 -i --style=file
language: system
types_or: [c++, inc]
- id: copyright-year-checker
......
cmake_minimum_required(VERSION 3.14)
set(version 1.1.0)
# Check support for CUDA/HIP in Cmake
project(composable_kernel)
project(composable_kernel VERSION ${version})
list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
if (DTYPES)
add_definitions(-DDTYPES)
if (DTYPES MATCHES "int8")
add_definitions(-D__int8__)
add_definitions(-DCK_ENABLE_INT8)
set(CK_ENABLE_INT8 "ON")
endif()
if (DTYPES MATCHES "fp8")
add_definitions(-D__fp8__)
add_definitions(-DCK_ENABLE_FP8)
set(CK_ENABLE_FP8 "ON")
endif()
if (DTYPES MATCHES "fp16")
add_definitions(-D__fp16__)
add_definitions(-DCK_ENABLE_FP16)
set(CK_ENABLE_FP16 "ON")
endif()
if (DTYPES MATCHES "fp32")
add_definitions(-D__fp32__)
add_definitions(-DCK_ENABLE_FP32)
set(CK_ENABLE_FP32 "ON")
endif()
if (DTYPES MATCHES "fp64")
add_definitions(-D__fp64__)
add_definitions(-DCK_ENABLE_FP64)
set(CK_ENABLE_FP64 "ON")
endif()
if (DTYPES MATCHES "bf16")
add_definitions(-D__bf16__)
add_definitions(-DCK_ENABLE_BF16)
set(CK_ENABLE_BF16 "ON")
endif()
message("DTYPES macro set to ${DTYPES}")
else()
add_definitions(-D__int8__ -D__fp8__ -D__fp16__ -D__fp32__ -D__fp64__ -D__bf16__)
add_definitions(-DCK_ENABLE_INT8 -DCK_ENABLE_FP8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16)
set(CK_ENABLE_ALL_DTYPES "ON")
endif()
if(DL_KERNELS)
add_definitions(-DDL_KERNELS)
set(CK_ENABLE_DL_KERNELS "ON")
endif()
if(INSTANCES_ONLY)
add_definitions(-DINSTANCES_ONLY)
set(CK_ENABLE_INSTANCES_ONLY "ON")
endif()
# CK config file to record supported datatypes, etc.
configure_file("${PROJECT_SOURCE_DIR}/include/ck/config.h.in" "${PROJECT_BINARY_DIR}/include/ck/config.h")
# CK version file to record release version as well as git commit hash
find_package(Git REQUIRED)
execute_process(COMMAND "${GIT_EXECUTABLE}" rev-parse HEAD OUTPUT_VARIABLE COMMIT_ID OUTPUT_STRIP_TRAILING_WHITESPACE)
configure_file("${PROJECT_SOURCE_DIR}/include/ck/version.h.in" "${PROJECT_BINARY_DIR}/include/ck/version.h")
enable_testing()
set(ROCM_SYMLINK_LIBS OFF)
......@@ -50,8 +68,10 @@ include(ROCMInstallSymlinks)
include(ROCMCreatePackage)
include(CheckCXXCompilerFlag)
include(ROCMCheckTargetIds)
rocm_setup_version(VERSION 0.2.0)
include(TargetFlags)
rocm_setup_version(VERSION ${version})
list(APPEND CMAKE_PREFIX_PATH ${CMAKE_INSTALL_PREFIX} ${CMAKE_INSTALL_PREFIX}/llvm ${CMAKE_INSTALL_PREFIX}/hip /opt/rocm /opt/rocm/llvm /opt/rocm/hip)
message("GPU_TARGETS= ${GPU_TARGETS}")
......@@ -315,13 +335,14 @@ set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib)
set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib)
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/bin)
# set CK project include directories
include_directories(BEFORE
${PROJECT_BINARY_DIR}/include
${PROJECT_SOURCE_DIR}/include
${PROJECT_SOURCE_DIR}/library/include
${HIP_INCLUDE_DIRS}
)
SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV")
if(BUILD_DEV)
add_compile_options(-Werror)
......@@ -409,7 +430,6 @@ endif()
#Create an interface target for the include only files and call it "composablekernels"
include(CMakePackageConfigHelpers)
set(version 1.0.0)
write_basic_package_version_file(
"${CMAKE_CURRENT_BINARY_DIR}/composable_kernelConfigVersion.cmake"
VERSION "${version}"
......@@ -428,6 +448,13 @@ rocm_install(FILES
DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel
)
# Install CK version and configuration files
install(FILES
${PROJECT_BINARY_DIR}/include/ck/version.h
${PROJECT_BINARY_DIR}/include/ck/config.h
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/ck/
)
set(CPACK_RESOURCE_FILE_LICENSE "${CMAKE_CURRENT_SOURCE_DIR}/LICENSE")
set(CPACK_RPM_PACKAGE_LICENSE "MIT")
......
......@@ -63,7 +63,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
nano \
zlib1g-dev \
openssh-server \
clang-format-10 \
clang-format-12 \
kmod && \
apt-get clean && \
rm -rf /var/lib/apt/lists/*
......
......@@ -689,7 +689,7 @@ pipeline {
-o -iname \'*.cpp.in\' \
-o -iname \'*.cl\' \
| grep -v 'build/' \
| xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-10 -style=file {} | diff - {}\'"
| xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-12 -style=file {} | diff - {}\'"
}
steps{
buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", execute_cmd: execute_cmd, no_reboot:true)
......
......@@ -191,6 +191,12 @@ int main(int argc, char* argv[])
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get());
SimpleDeviceMem workspace(workspace_sz);
op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer());
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
}
......
......@@ -187,6 +187,12 @@ int main(int argc, char* argv[])
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get());
SimpleDeviceMem workspace(workspace_sz);
op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer());
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
}
......
add_executable(client_splitK_gemm splitK_gemm_fp16_f8.cpp)
target_link_libraries(client_splitK_gemm PRIVATE composable_kernel::device_operations)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iomanip>
#include <vector>
#include <iostream>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_splitk.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp"
using F8 = ck::f8_t;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
using ADataType = F8;
using BDataType = F16;
using CDataType = F16;
using ALayout = Row;
using BLayout = Col;
using CLayout = Row;
struct SimpleDeviceMem
{
SimpleDeviceMem() = delete;
SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
{
(void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
}
void* GetDeviceBuffer() { return p_mem_; }
~SimpleDeviceMem() { (void)hipFree(p_mem_); }
void* p_mem_;
};
int main(int argc, char* argv[])
{
// GEMM shape
ck::index_t M = 3840;
ck::index_t N = 4096;
ck::index_t K = 4096;
ck::index_t StrideA = 4096;
ck::index_t StrideB = 4096;
ck::index_t StrideC = 4096;
ck::index_t KBatch = 1;
if(argc == 1)
{
// use default case
}
else if(argc == 8)
{
M = std::stoi(argv[1]);
N = std::stoi(argv[2]);
K = std::stoi(argv[3]);
StrideA = std::stoi(argv[4]);
StrideB = std::stoi(argv[5]);
StrideC = std::stoi(argv[6]);
KBatch = std::stoi(argv[7]);
}
else
{
printf("arg1 to 7: M, N, K, StrideA, StrideB, StrideC, KBatch\n");
exit(0);
}
auto f_matrix_space_size =
[](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) {
using Layout = decltype(layout);
if constexpr(std::is_same<Layout, ck::tensor_layout::gemm::RowMajor>::value)
{
return (nRow - 1) * stride + nCol;
}
else
{
return (nCol - 1) * stride + nRow;
}
};
SimpleDeviceMem a_device_buf(sizeof(ADataType) * f_matrix_space_size(M, K, StrideA, ALayout{}));
SimpleDeviceMem b_device_buf(sizeof(BDataType) * f_matrix_space_size(K, N, StrideB, BLayout{}));
SimpleDeviceMem c_device_buf(sizeof(CDataType) * f_matrix_space_size(M, N, StrideC, CLayout{}));
using DeviceOp = ck::tensor_operation::device::DeviceGemmSplitK<
ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
// get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
const auto a_element_op = AElementOp{};
const auto b_element_op = BElementOp{};
const auto c_element_op = CElementOp{};
std::string best_op_name;
bool found = false;
int best_op_id = -1;
float best_ave_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
// profile device operation instances
std::cout << "Run all instances and do timing" << std::endl;
for(int i = 0; i < op_ptrs.size(); ++i)
{
auto& op_ptr = op_ptrs[i];
auto argument_ptr = op_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
c_device_buf.GetDeviceBuffer(),
M,
N,
K,
StrideA,
StrideB,
StrideC,
a_element_op,
b_element_op,
c_element_op,
KBatch);
auto invoker_ptr = op_ptr->MakeInvokerPointer();
std::string op_name = op_ptr->GetTypeString();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
<< gb_per_sec << " GB/s, " << op_name << std::endl;
if(tflops > best_tflops)
{
found = true;
best_op_id = i;
best_op_name = op_name;
best_tflops = tflops;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
}
}
else
{
std::cout << op_name << " does not support this problem" << std::endl;
}
}
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
// run the best intance
{
auto& op_ptr = op_ptrs[best_op_id];
std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
<< std::endl;
auto argument_ptr = op_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
c_device_buf.GetDeviceBuffer(),
M,
N,
K,
StrideA,
StrideB,
StrideC,
a_element_op,
b_element_op,
c_element_op,
KBatch);
auto invoker_ptr = op_ptr->MakeInvokerPointer();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
}
std::cout << "Done" << std::endl;
}
return 0;
}
add_executable(client_grouped_gemm_fixed_nk_bias_fp16 grouped_gemm_fixed_nk_bias_fp16.cpp)
target_link_libraries(client_grouped_gemm_fixed_nk_bias_fp16 PRIVATE composable_kernel::device_operations)
......@@ -20,7 +20,7 @@ using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddBias = ck::tensor_operation::element_wise::AddBias;
using Add = ck::tensor_operation::element_wise::Add;
using ADataType = F16;
using BDataType = F16;
......@@ -36,7 +36,7 @@ using ELayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = AddBias;
using CDEElementOp = Add;
struct SimpleDeviceMem
{
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iomanip>
#include <iostream>
#include <vector>
#include <random>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp"
using F8 = ck::f8_t;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = F16;
using BDataType = F8;
using DsDataType = ck::Tuple<>;
using EDataType = F16;
using ALayout = Row;
using BLayout = Col;
using DsLayout = ck::Tuple<>;
using ELayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = PassThrough;
struct SimpleDeviceMem
{
SimpleDeviceMem() = delete;
SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
{
(void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
}
void* GetDeviceBuffer() { return p_mem_; }
~SimpleDeviceMem() { (void)hipFree(p_mem_); }
void* p_mem_;
};
int main()
{
std::vector<int> Ms, Ns, Ks, StrideAs, StrideBs, StrideEs;
int sum_of_m = 0;
Ms = {167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 168, 156, 168, 148};
int group_count = Ms.size();
for(int i = 0; i < group_count; ++i)
{
Ns.push_back(768);
Ks.push_back(4608);
StrideAs.push_back(std::is_same<Row, ALayout>::value ? Ks[i] : Ms[i]);
StrideBs.push_back(std::is_same<Row, BLayout>::value ? Ns[i] : Ks[i]);
StrideEs.push_back(std::is_same<Row, ELayout>::value ? Ns[i] : Ms[i]);
sum_of_m += Ms[i];
}
auto f_matrix_space_size =
[](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) {
using Layout = decltype(layout);
if constexpr(std::is_same<Layout, ck::tensor_layout::gemm::RowMajor>::value)
{
return (nRow - 1) * stride + nCol;
}
else
{
return (nCol - 1) * stride + nRow;
}
};
std::vector<SimpleDeviceMem> a_dev_bufs, b_dev_bufs, e_dev_bufs;
a_dev_bufs.reserve(group_count);
b_dev_bufs.reserve(group_count);
e_dev_bufs.reserve(group_count);
std::vector<void*> p_e;
p_e.reserve(group_count);
std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
gemm_descs.reserve(group_count);
std::vector<ck::tensor_operation::device::GroupedGemmKernelArgument<1>>
grouped_gemm_kernel_args_;
grouped_gemm_kernel_args_.reserve(group_count);
for(int i = 0; i < group_count; ++i)
{
a_dev_bufs.emplace_back(sizeof(ADataType) *
f_matrix_space_size(Ms[i], Ks[i], StrideAs[i], ALayout{}));
b_dev_bufs.emplace_back(sizeof(BDataType) *
f_matrix_space_size(Ks[i], Ns[i], StrideBs[i], BLayout{}));
e_dev_bufs.emplace_back(sizeof(EDataType) *
f_matrix_space_size(Ms[i], Ns[i], StrideEs[i], ELayout{}));
gemm_descs.push_back({sum_of_m, Ns[i], Ks[i], 1, StrideBs[i], 1, {0}});
p_e.push_back(e_dev_bufs[i].GetDeviceBuffer());
grouped_gemm_kernel_args_.push_back({a_dev_bufs[i].GetDeviceBuffer(),
b_dev_bufs[i].GetDeviceBuffer(),
{},
e_dev_bufs[i].GetDeviceBuffer(),
Ms[i],
Ns[i],
Ks[i],
StrideAs[i],
StrideBs[i],
{},
StrideEs[i]});
}
using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemmFixedNK<ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementOp,
BElementOp,
CDEElementOp>;
// get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
const auto a_element_op = AElementOp{};
const auto b_element_op = BElementOp{};
const auto cde_element_op = CDEElementOp{};
std::string best_op_name;
bool found = false;
int best_op_id = -1;
float best_ave_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
// profile device operation instances
std::cout << "Run all instances and do timing" << std::endl;
std::vector<const void*> p_a = {}, p_b = {};
std::vector<std::array<const void*, 0>> p_ds = {};
for(int i = 0; i < op_ptrs.size(); ++i)
{
auto& op_ptr = op_ptrs[i];
auto argument_ptr = op_ptr->MakeArgumentPointer(
p_a, p_b, p_ds, p_e, gemm_descs, a_element_op, b_element_op, cde_element_op);
auto invoker_ptr = op_ptr->MakeInvokerPointer();
SimpleDeviceMem grouped_gemm_kernel_args_dev(
op_ptr->GetDeviceKernelArgSize(argument_ptr.get()));
SimpleDeviceMem grouped_gemm_workspace_dev(op_ptr->GetWorkSpaceSize(argument_ptr.get()));
std::string op_name = op_ptr->GetTypeString();
hipGetErrorString(hipMemcpy(grouped_gemm_kernel_args_dev.GetDeviceBuffer(),
grouped_gemm_kernel_args_.data(),
op_ptr->GetDeviceKernelArgSize(argument_ptr.get()),
hipMemcpyHostToDevice));
op_ptr->SetWorkSpacePointer(argument_ptr.get(),
grouped_gemm_workspace_dev.GetDeviceBuffer());
op_ptr->SetDeviceKernelArgs(argument_ptr.get(),
grouped_gemm_kernel_args_dev.GetDeviceBuffer());
op_ptr->SetKBatch(argument_ptr.get(), 16);
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
std::size_t flop = 0, num_btype = 0;
for(std::size_t j = 0; j < gemm_descs.size(); ++j)
{
flop += std::size_t(2) * Ms[j] * Ns[j] * Ks[j];
num_btype += sizeof(ADataType) * Ms[j] * Ks[j] + sizeof(BDataType) * Ks[j] * Ns[j] +
sizeof(EDataType) * Ms[j] * Ns[j];
}
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
<< gb_per_sec << " GB/s, " << op_name << std::endl;
if(tflops > best_tflops)
{
found = true;
best_op_id = i;
best_op_name = op_name;
best_tflops = tflops;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
}
}
else
{
std::cout << op_name << " does not support this problem" << std::endl;
}
}
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
return 0;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iomanip>
#include <iostream>
#include <vector>
#include <random>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp"
using I8 = int8_t;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = F16;
using BDataType = I8;
using DsDataType = ck::Tuple<>;
using EDataType = F16;
using ALayout = Row;
using BLayout = Row;
using DsLayout = ck::Tuple<>;
using ELayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = PassThrough;
struct SimpleDeviceMem
{
SimpleDeviceMem() = delete;
SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
{
(void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
}
void* GetDeviceBuffer() { return p_mem_; }
~SimpleDeviceMem() { (void)hipFree(p_mem_); }
void* p_mem_;
};
int main()
{
std::vector<int> Ms, Ns, Ks, StrideAs, StrideBs, StrideEs;
int sum_of_m = 0;
Ms = {167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 168, 156, 168, 148};
int group_count = Ms.size();
for(int i = 0; i < group_count; ++i)
{
Ns.push_back(768);
Ks.push_back(4608);
StrideAs.push_back(std::is_same<Row, ALayout>::value ? Ks[i] : Ms[i]);
StrideBs.push_back(std::is_same<Row, BLayout>::value ? Ns[i] : Ks[i]);
StrideEs.push_back(std::is_same<Row, ELayout>::value ? Ns[i] : Ms[i]);
sum_of_m += Ms[i];
}
auto f_matrix_space_size =
[](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) {
using Layout = decltype(layout);
if constexpr(std::is_same<Layout, ck::tensor_layout::gemm::RowMajor>::value)
{
return (nRow - 1) * stride + nCol;
}
else
{
return (nCol - 1) * stride + nRow;
}
};
std::vector<SimpleDeviceMem> a_dev_bufs, b_dev_bufs, e_dev_bufs;
a_dev_bufs.reserve(group_count);
b_dev_bufs.reserve(group_count);
e_dev_bufs.reserve(group_count);
std::vector<void*> p_e;
p_e.reserve(group_count);
std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
gemm_descs.reserve(group_count);
std::vector<ck::tensor_operation::device::GroupedGemmKernelArgument<1>>
grouped_gemm_kernel_args_;
grouped_gemm_kernel_args_.reserve(group_count);
for(int i = 0; i < group_count; ++i)
{
a_dev_bufs.emplace_back(sizeof(ADataType) *
f_matrix_space_size(Ms[i], Ks[i], StrideAs[i], ALayout{}));
b_dev_bufs.emplace_back(sizeof(BDataType) *
f_matrix_space_size(Ks[i], Ns[i], StrideBs[i], BLayout{}));
e_dev_bufs.emplace_back(sizeof(EDataType) *
f_matrix_space_size(Ms[i], Ns[i], StrideEs[i], ELayout{}));
gemm_descs.push_back({sum_of_m, Ns[i], Ks[i], 1, StrideBs[i], 1, {0}});
p_e.push_back(e_dev_bufs[i].GetDeviceBuffer());
grouped_gemm_kernel_args_.push_back({a_dev_bufs[i].GetDeviceBuffer(),
b_dev_bufs[i].GetDeviceBuffer(),
{},
e_dev_bufs[i].GetDeviceBuffer(),
Ms[i],
Ns[i],
Ks[i],
StrideAs[i],
StrideBs[i],
{},
StrideEs[i]});
}
using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemmFixedNK<ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementOp,
BElementOp,
CDEElementOp>;
// get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
const auto a_element_op = AElementOp{};
const auto b_element_op = BElementOp{};
const auto cde_element_op = CDEElementOp{};
std::string best_op_name;
bool found = false;
int best_op_id = -1;
float best_ave_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
// profile device operation instances
std::cout << "Run all instances and do timing" << std::endl;
std::vector<const void*> p_a = {}, p_b = {};
std::vector<std::array<const void*, 0>> p_ds = {};
for(int i = 0; i < op_ptrs.size(); ++i)
{
auto& op_ptr = op_ptrs[i];
auto argument_ptr = op_ptr->MakeArgumentPointer(
p_a, p_b, p_ds, p_e, gemm_descs, a_element_op, b_element_op, cde_element_op);
auto invoker_ptr = op_ptr->MakeInvokerPointer();
SimpleDeviceMem grouped_gemm_kernel_args_dev(
op_ptr->GetDeviceKernelArgSize(argument_ptr.get()));
SimpleDeviceMem grouped_gemm_workspace_dev(op_ptr->GetWorkSpaceSize(argument_ptr.get()));
std::string op_name = op_ptr->GetTypeString();
hipGetErrorString(hipMemcpy(grouped_gemm_kernel_args_dev.GetDeviceBuffer(),
grouped_gemm_kernel_args_.data(),
op_ptr->GetDeviceKernelArgSize(argument_ptr.get()),
hipMemcpyHostToDevice));
op_ptr->SetWorkSpacePointer(argument_ptr.get(),
grouped_gemm_workspace_dev.GetDeviceBuffer());
op_ptr->SetDeviceKernelArgs(argument_ptr.get(),
grouped_gemm_kernel_args_dev.GetDeviceBuffer());
op_ptr->SetKBatch(argument_ptr.get(), 32);
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
std::size_t flop = 0, num_btype = 0;
for(std::size_t j = 0; j < gemm_descs.size(); ++j)
{
flop += std::size_t(2) * Ms[j] * Ns[j] * Ks[j];
num_btype += sizeof(ADataType) * Ms[j] * Ks[j] + sizeof(BDataType) * Ks[j] * Ns[j] +
sizeof(EDataType) * Ms[j] * Ns[j];
}
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
<< gb_per_sec << " GB/s, " << op_name << std::endl;
if(tflops > best_tflops)
{
found = true;
best_op_id = i;
best_op_name = op_name;
best_tflops = tflops;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
}
}
else
{
std::cout << op_name << " does not support this problem" << std::endl;
}
}
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
return 0;
}
add_executable(client_grouped_gemm_fixed_nk_bias_fp16 grouped_gemm_fixed_nk_bias_fp16.cpp)
target_link_libraries(client_grouped_gemm_fixed_nk_bias_fp16 PRIVATE composable_kernel::device_operations)
add_executable(client_grouped_gemm_fixed_nk_fp16 grouped_gemm_fixed_nk_fp16.cpp)
target_link_libraries(client_grouped_gemm_fixed_nk_fp16 PRIVATE composable_kernel::device_operations)
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iomanip>
#include <iostream>
#include <vector>
#include <random>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_bias.hpp"
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Add = ck::tensor_operation::element_wise::Add;
using ADataType = F16;
using BDataType = F16;
using D0DataType = F32;
using DsDataType = ck::Tuple<D0DataType>;
using EDataType = F32;
using ALayout = Row;
using BLayout = Row;
using D0Layout = Row;
using DsLayout = ck::Tuple<D0Layout>;
using ELayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = Add;
struct SimpleDeviceMem
{
SimpleDeviceMem() = delete;
SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
{
(void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
}
void* GetDeviceBuffer() { return p_mem_; }
~SimpleDeviceMem() { (void)hipFree(p_mem_); }
void* p_mem_;
};
int main()
{
std::vector<int> Ms, Ns, Ks, StrideAs, StrideBs, StrideEs;
int sum_of_m = 0;
Ms = {167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 168, 156, 168, 148};
int group_count = Ms.size();
for(int i = 0; i < group_count; ++i)
{
Ns.push_back(768);
Ks.push_back(4608);
StrideAs.push_back(std::is_same<Row, ALayout>::value ? Ks[i] : Ms[i]);
StrideBs.push_back(std::is_same<Row, BLayout>::value ? Ns[i] : Ks[i]);
StrideEs.push_back(std::is_same<Row, ELayout>::value ? Ns[i] : Ms[i]);
sum_of_m += Ms[i];
}
auto f_matrix_space_size =
[](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) {
using Layout = decltype(layout);
if constexpr(std::is_same<Layout, ck::tensor_layout::gemm::RowMajor>::value)
{
return (nRow - 1) * stride + nCol;
}
else
{
return (nCol - 1) * stride + nRow;
}
};
std::vector<SimpleDeviceMem> a_dev_bufs, b_dev_bufs, d0_dev_bufs, e_dev_bufs;
a_dev_bufs.reserve(group_count);
b_dev_bufs.reserve(group_count);
d0_dev_bufs.reserve(group_count);
e_dev_bufs.reserve(group_count);
std::vector<void*> p_e;
p_e.reserve(group_count);
std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
gemm_descs.reserve(group_count);
std::vector<ck::tensor_operation::device::GroupedGemmKernelArgument<1>>
grouped_gemm_kernel_args_;
grouped_gemm_kernel_args_.reserve(group_count);
for(int i = 0; i < group_count; ++i)
{
a_dev_bufs.emplace_back(sizeof(ADataType) *
f_matrix_space_size(Ms[i], Ks[i], StrideAs[i], ALayout{}));
b_dev_bufs.emplace_back(sizeof(BDataType) *
f_matrix_space_size(Ks[i], Ns[i], StrideBs[i], BLayout{}));
d0_dev_bufs.emplace_back(sizeof(D0DataType) *
f_matrix_space_size(Ms[i], Ns[i], 0, D0Layout{}));
e_dev_bufs.emplace_back(sizeof(EDataType) *
f_matrix_space_size(Ms[i], Ns[i], StrideEs[i], ELayout{}));
gemm_descs.push_back({sum_of_m, Ns[i], Ks[i], 1, StrideBs[i], 1, {0}});
p_e.push_back(e_dev_bufs[i].GetDeviceBuffer());
grouped_gemm_kernel_args_.push_back(
{a_dev_bufs[i].GetDeviceBuffer(),
b_dev_bufs[i].GetDeviceBuffer(),
std::array<const void*, 1>{d0_dev_bufs[i].GetDeviceBuffer()},
e_dev_bufs[i].GetDeviceBuffer(),
Ms[i],
Ns[i],
Ks[i],
StrideAs[i],
StrideBs[i],
std::array<ck::index_t, 1>{0},
StrideEs[i]});
}
using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemmFixedNK<ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementOp,
BElementOp,
CDEElementOp>;
// get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
const auto a_element_op = AElementOp{};
const auto b_element_op = BElementOp{};
const auto cde_element_op = CDEElementOp{};
std::string best_op_name;
bool found = false;
int best_op_id = -1;
float best_ave_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
// profile device operation instances
std::cout << "Run all instances and do timing" << std::endl;
std::vector<const void*> p_a = {}, p_b = {};
std::vector<std::array<const void*, 1>> p_ds = {};
for(int i = 0; i < op_ptrs.size(); ++i)
{
auto& op_ptr = op_ptrs[i];
auto argument_ptr = op_ptr->MakeArgumentPointer(
p_a, p_b, p_ds, p_e, gemm_descs, a_element_op, b_element_op, cde_element_op);
auto invoker_ptr = op_ptr->MakeInvokerPointer();
SimpleDeviceMem grouped_gemm_kernel_args_dev(
op_ptr->GetDeviceKernelArgSize(argument_ptr.get()));
SimpleDeviceMem grouped_gemm_workspace_dev(op_ptr->GetWorkSpaceSize(argument_ptr.get()));
std::string op_name = op_ptr->GetTypeString();
hipGetErrorString(hipMemcpy(grouped_gemm_kernel_args_dev.GetDeviceBuffer(),
grouped_gemm_kernel_args_.data(),
op_ptr->GetDeviceKernelArgSize(argument_ptr.get()),
hipMemcpyHostToDevice));
op_ptr->SetWorkSpacePointer(argument_ptr.get(),
grouped_gemm_workspace_dev.GetDeviceBuffer());
op_ptr->SetDeviceKernelArgs(argument_ptr.get(),
grouped_gemm_kernel_args_dev.GetDeviceBuffer());
op_ptr->SetKBatch(argument_ptr.get(), 2);
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
std::size_t flop = 0, num_btype = 0;
for(std::size_t j = 0; j < gemm_descs.size(); ++j)
{
flop += std::size_t(2) * Ms[j] * Ns[j] * Ks[j];
num_btype += sizeof(ADataType) * Ms[j] * Ks[j] + sizeof(BDataType) * Ks[j] * Ns[j] +
sizeof(EDataType) * Ms[j] * Ns[j];
}
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
<< gb_per_sec << " GB/s, " << op_name << std::endl;
if(tflops > best_tflops)
{
found = true;
best_op_id = i;
best_op_name = op_name;
best_tflops = tflops;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
}
}
else
{
std::cout << op_name << " does not support this problem" << std::endl;
}
}
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
return 0;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iomanip>
#include <iostream>
#include <vector>
#include <random>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp"
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = F16;
using BDataType = F16;
using DsDataType = ck::Tuple<>;
using EDataType = F16;
using ALayout = Row;
using BLayout = Row;
using DsLayout = ck::Tuple<>;
using ELayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = PassThrough;
struct SimpleDeviceMem
{
SimpleDeviceMem() = delete;
SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
{
(void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
}
void* GetDeviceBuffer() { return p_mem_; }
~SimpleDeviceMem() { (void)hipFree(p_mem_); }
void* p_mem_;
};
int main()
{
std::vector<int> Ms, Ns, Ks, StrideAs, StrideBs, StrideEs;
int sum_of_m = 0;
// Ms = {167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 168, 156, 168, 148};
Ms = {0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0};
int group_count = Ms.size();
for(int i = 0; i < group_count; ++i)
{
Ns.push_back(768);
Ks.push_back(4608);
StrideAs.push_back(std::is_same<Row, ALayout>::value ? Ks[i] : Ms[i]);
StrideBs.push_back(std::is_same<Row, BLayout>::value ? Ns[i] : Ks[i]);
StrideEs.push_back(std::is_same<Row, ELayout>::value ? Ns[i] : Ms[i]);
sum_of_m += Ms[i];
}
auto f_matrix_space_size =
[](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) {
using Layout = decltype(layout);
if constexpr(std::is_same<Layout, ck::tensor_layout::gemm::RowMajor>::value)
{
return (nRow - 1) * stride + nCol;
}
else
{
return (nCol - 1) * stride + nRow;
}
};
std::vector<SimpleDeviceMem> a_dev_bufs, b_dev_bufs, e_dev_bufs;
a_dev_bufs.reserve(group_count);
b_dev_bufs.reserve(group_count);
e_dev_bufs.reserve(group_count);
std::vector<void*> p_e;
p_e.reserve(group_count);
std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
gemm_descs.reserve(group_count);
std::vector<ck::tensor_operation::device::GroupedGemmKernelArgument<1>>
grouped_gemm_kernel_args_;
grouped_gemm_kernel_args_.reserve(group_count);
for(int i = 0; i < group_count; ++i)
{
a_dev_bufs.emplace_back(sizeof(ADataType) *
f_matrix_space_size(Ms[i], Ks[i], StrideAs[i], ALayout{}));
b_dev_bufs.emplace_back(sizeof(BDataType) *
f_matrix_space_size(Ks[i], Ns[i], StrideBs[i], BLayout{}));
e_dev_bufs.emplace_back(sizeof(EDataType) *
f_matrix_space_size(Ms[i], Ns[i], StrideEs[i], ELayout{}));
gemm_descs.push_back({sum_of_m, Ns[i], Ks[i], 1, StrideBs[i], 1, {0}});
p_e.push_back(e_dev_bufs[i].GetDeviceBuffer());
grouped_gemm_kernel_args_.push_back({a_dev_bufs[i].GetDeviceBuffer(),
b_dev_bufs[i].GetDeviceBuffer(),
{},
e_dev_bufs[i].GetDeviceBuffer(),
Ms[i],
Ns[i],
Ks[i],
StrideAs[i],
StrideBs[i],
{},
StrideEs[i]});
}
using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemmFixedNK<ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementOp,
BElementOp,
CDEElementOp>;
// get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
const auto a_element_op = AElementOp{};
const auto b_element_op = BElementOp{};
const auto cde_element_op = CDEElementOp{};
std::string best_op_name;
bool found = false;
int best_op_id = -1;
float best_ave_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
// profile device operation instances
std::cout << "Run all instances and do timing" << std::endl;
std::vector<const void*> p_a = {}, p_b = {};
std::vector<std::array<const void*, 0>> p_ds = {};
for(int i = 0; i < op_ptrs.size(); ++i)
{
auto& op_ptr = op_ptrs[i];
auto argument_ptr = op_ptr->MakeArgumentPointer(
p_a, p_b, p_ds, p_e, gemm_descs, a_element_op, b_element_op, cde_element_op);
auto invoker_ptr = op_ptr->MakeInvokerPointer();
SimpleDeviceMem grouped_gemm_kernel_args_dev(
op_ptr->GetDeviceKernelArgSize(argument_ptr.get()));
SimpleDeviceMem grouped_gemm_workspace_dev(op_ptr->GetWorkSpaceSize(argument_ptr.get()));
std::string op_name = op_ptr->GetTypeString();
hipGetErrorString(hipMemcpy(grouped_gemm_kernel_args_dev.GetDeviceBuffer(),
grouped_gemm_kernel_args_.data(),
op_ptr->GetDeviceKernelArgSize(argument_ptr.get()),
hipMemcpyHostToDevice));
op_ptr->SetWorkSpacePointer(argument_ptr.get(),
grouped_gemm_workspace_dev.GetDeviceBuffer());
op_ptr->SetDeviceKernelArgs(argument_ptr.get(),
grouped_gemm_kernel_args_dev.GetDeviceBuffer());
op_ptr->SetKBatch(argument_ptr.get(), 32);
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
std::size_t flop = 0, num_btype = 0;
for(std::size_t j = 0; j < gemm_descs.size(); ++j)
{
flop += std::size_t(2) * Ms[j] * Ns[j] * Ks[j];
num_btype += sizeof(ADataType) * Ms[j] * Ks[j] + sizeof(BDataType) * Ks[j] * Ns[j] +
sizeof(EDataType) * Ms[j] * Ns[j];
}
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
<< gb_per_sec << " GB/s, " << op_name << std::endl;
if(tflops > best_tflops)
{
found = true;
best_op_id = i;
best_op_name = op_name;
best_tflops = tflops;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
}
}
else
{
std::cout << op_name << " does not support this problem" << std::endl;
}
}
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
return 0;
}
......@@ -5,29 +5,50 @@ add_compile_options(-std=c++17)
if (DTYPES)
add_definitions(-DDTYPES)
if (DTYPES MATCHES "int8")
add_definitions(-D__int8__)
add_definitions(-DCK_ENABLE_INT8)
if(NOT DEFINED ${CK_ENABLE_INT8})
set(CK_ENABLE_INT8 "ON")
endif()
endif()
if (DTYPES MATCHES "fp8")
add_definitions(-D__fp8__)
add_definitions(-DCK_ENABLE_FP8)
if(NOT DEFINED ${CK_ENABLE_FP8})
set(CK_ENABLE_FP8 "ON")
endif()
endif()
if (DTYPES MATCHES "fp16")
add_definitions(-D__fp16__)
add_definitions(-DCK_ENABLE_FP16)
if(NOT DEFINED ${CK_ENABLE_FP16})
set(CK_ENABLE_FP16 "ON")
endif()
endif()
if (DTYPES MATCHES "fp32")
add_definitions(-D__fp32__)
add_definitions(-DCK_ENABLE_FP32)
if(NOT DEFINED ${CK_ENABLE_FP32})
set(CK_ENABLE_FP32 "ON")
endif()
endif()
if (DTYPES MATCHES "fp64")
add_definitions(-D__fp64__)
add_definitions(-DCK_ENABLE_FP64)
if(NOT DEFINED ${CK_ENABLE_FP64})
set(CK_ENABLE_FP64 "ON")
endif()
endif()
if (DTYPES MATCHES "bf16")
add_definitions(-D__bf16__)
add_definitions(-DCK_ENABLE_BF16)
if(NOT DEFINED ${CK_ENABLE_BF16})
set(CK_ENABLE_BF16 "ON")
endif()
endif()
message("DTYPES macro set to ${DTYPES}")
else()
add_definitions(-D__int8__ -D__fp8__ -D__fp16__ -D__fp32__ -D__fp64__ -D__bf16__)
add_definitions(-DCK_ENABLE_INT8 -DCK_ENABLE_FP8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16)
if(NOT DEFINED ${CK_ENABLE_ALL_DTYPES})
set(CK_ENABLE_ALL_DTYPES "ON")
endif()
endif()
find_package(composable_kernel 1.0.0 COMPONENTS device_operations)
find_package(composable_kernel COMPONENTS device_operations)
find_package(hip REQUIRED PATHS /opt/rocm)
message(STATUS "Build with HIP ${hip_VERSION}")
......
......@@ -40,6 +40,9 @@ endif()
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
add_example_executable(example_gemm_xdl_bf16 gemm_xdl_bf16.cpp)
add_dependencies(example_gemm_xdl example_gemm_xdl_bf16)
add_example_executable(example_gemm_xdl_bf16_rtn gemm_xdl_bf16_rtn.cpp)
add_dependencies(example_gemm_xdl example_gemm_xdl_bf16_rtn)
endif()
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment