Commit ba251e4a authored by Umang Yadav
Browse files

Formatting and put find_package(hip) behind JIT_LIB flag

parent 000c8bcf
......@@ -108,15 +108,6 @@ if(GPU_TARGETS)
else()
message("Building CK for the following targets: ${AMDGPU_TARGETS}")
endif()
find_package(hip)
# No assumption that HIP kernels are launched with uniform block size for backward compatibility
# SWDEV-413293 and https://reviews.llvm.org/D155213
math(EXPR hip_VERSION_FLAT "(${hip_VERSION_MAJOR} * 1000 + ${hip_VERSION_MINOR}) * 100000 + ${hip_VERSION_PATCH}")
message("hip_version_flat=${hip_VERSION_FLAT}")
if(${hip_VERSION_FLAT} GREATER 500723302)
message("Adding the fno-offload-uniform-block compiler flag")
add_compile_options(-fno-offload-uniform-block)
endif()
option(USE_BITINT_EXTENSION_INT4, "Whether to enable clang's BitInt extension to provide int4 data type." OFF)
option(USE_OPT_NAVI3X, "Whether to enable LDS cumode and Wavefront32 mode for NAVI3X silicons." OFF)
......@@ -147,6 +138,16 @@ message("CMAKE_CXX_COMPILER_ID: ${CMAKE_CXX_COMPILER_ID}")
option(CK_BUILD_JIT_LIB, "Only build the CK JIT Helper Library" OFF)
if (NOT CK_BUILD_JIT_LIB)
find_package(hip)
# No assumption that HIP kernels are launched with uniform block size for backward compatibility
# SWDEV-413293 and https://reviews.llvm.org/D155213
math(EXPR hip_VERSION_FLAT "(${hip_VERSION_MAJOR} * 1000 + ${hip_VERSION_MINOR}) * 100000 + ${hip_VERSION_PATCH}")
message("hip_version_flat=${hip_VERSION_FLAT}")
if(${hip_VERSION_FLAT} GREATER 500723302)
message("Adding the fno-offload-uniform-block compiler flag")
add_compile_options(-fno-offload-uniform-block)
endif()
option(USE_BITINT_EXTENSION_INT4, "Whether to enable clang's BitInt extension to provide int4 data type." OFF)
option(USE_OPT_NAVI3X, "Whether to enable LDS cumode and Wavefront32 mode for NAVI3X silicons." OFF)
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -99,7 +99,8 @@ struct Problem
static const std::size_t B1BlockLdsAddExtraN_idx = 52;
static const std::size_t CShuffleMXdlPerWavePerShuffle_idx = 53;
static const std::size_t CShuffleNXdlPerWavePerShuffle_idx = 54;
static const std::size_t CBlockTransferClusterLengths_MBlock_MWaveMPerXdl_NBlock_NWaveNPerXdl_idx = 55;
static const std::size_t
CBlockTransferClusterLengths_MBlock_MWaveMPerXdl_NBlock_NWaveNPerXdl_idx = 55;
static const std::size_t CBlockTransferScalarPerVector_NWaveNPerXdl_idx = 56;
static const std::size_t MaskOutUpperTriangle_idx = 57;
};
......
......@@ -80,7 +80,8 @@ Solution Problem::MakeSolution(std::size_t idx, const std::string& arch) const
const std::size_t k_per_block = std::stoi(k_per_block_str);
const std::size_t n1_per_block = std::stoi(n1_per_block_str);
const std::size_t grid_size = GetGridSize(M, O, m_per_block, n1_per_block);
params[GEMMSpecialization_idx] = GetGemmSpec(M, N, K, O, m_per_block, n_per_block, k_per_block, n1_per_block);
params[GEMMSpecialization_idx] =
GetGemmSpec(M, N, K, O, m_per_block, n_per_block, k_per_block, n1_per_block);
std::string str = std::accumulate(
params.begin() + 1,
......
......@@ -101,18 +101,21 @@ Solution Problem::MakeSolution(std::size_t idx, const std::string& arch) const
if(ADataType == DataType::Int8 and BDataType == DataType::Int8)
{
// Change CBlockTransfer ScalarPerVector if Ds contains other types
if(EDataType == DataType::Half or std::any_of(
DsDataType.begin(), DsDataType.end(), [](auto t) { return t == DataType::Half; }))
if(EDataType == DataType::Half or std::any_of(DsDataType.begin(),
DsDataType.end(),
[](auto t) { return t == DataType::Half; }))
{
params[params.size() - 3] = "8";
}
if(EDataType == DataType::Float or std::any_of(
DsDataType.begin(), DsDataType.end(), [](auto t) { return t == DataType::Float; }))
if(EDataType == DataType::Float or std::any_of(DsDataType.begin(),
DsDataType.end(),
[](auto t) { return t == DataType::Float; }))
{
params[params.size() - 3] = "4";
}
if(EDataType == DataType::Int32 or std::any_of(
DsDataType.begin(), DsDataType.end(), [](auto t) { return t == DataType::Int32; }))
if(EDataType == DataType::Int32 or std::any_of(DsDataType.begin(),
DsDataType.end(),
[](auto t) { return t == DataType::Int32; }))
{
params[params.size() - 3] = "4";
}
......@@ -141,7 +144,7 @@ Solution Problem::MakeSolution(std::size_t idx, const std::string& arch) const
std::string{},
[](const std::string& a, const std::string& b) { return a.empty() ? b : a + ", " + b; });
str = params.front() + "< " + str + ">";
if (params.back().find("v2") != std::string::npos and K % k_per_block != 0)
if(params.back().find("v2") != std::string::npos and K % k_per_block != 0)
str = "";
return Solution{str, block_size, grid_size};
......@@ -159,7 +162,7 @@ std::vector<Solution> Problem::GetSolutions(const std::string& arch) const
for(std::size_t i = 0; i < num_instances; ++i)
{
auto solution = MakeSolution(i, arch);
if (solution.template_str != "")
if(solution.template_str != "")
solutions.push_back(solution);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment