Commit ed305f6b authored by Umang Yadav
Browse files

formatting

parent 9f4e3544
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -99,7 +99,8 @@ struct Problem ...@@ -99,7 +99,8 @@ struct Problem
static const std::size_t B1BlockLdsAddExtraN_idx = 52; static const std::size_t B1BlockLdsAddExtraN_idx = 52;
static const std::size_t CShuffleMXdlPerWavePerShuffle_idx = 53; static const std::size_t CShuffleMXdlPerWavePerShuffle_idx = 53;
static const std::size_t CShuffleNXdlPerWavePerShuffle_idx = 54; static const std::size_t CShuffleNXdlPerWavePerShuffle_idx = 54;
static const std::size_t CBlockTransferClusterLengths_MBlock_MWaveMPerXdl_NBlock_NWaveNPerXdl_idx = 55; static const std::size_t
CBlockTransferClusterLengths_MBlock_MWaveMPerXdl_NBlock_NWaveNPerXdl_idx = 55;
static const std::size_t CBlockTransferScalarPerVector_NWaveNPerXdl_idx = 56; static const std::size_t CBlockTransferScalarPerVector_NWaveNPerXdl_idx = 56;
static const std::size_t MaskOutUpperTriangle_idx = 57; static const std::size_t MaskOutUpperTriangle_idx = 57;
}; };
......
...@@ -80,7 +80,8 @@ Solution Problem::MakeSolution(std::size_t idx, const std::string& arch) const ...@@ -80,7 +80,8 @@ Solution Problem::MakeSolution(std::size_t idx, const std::string& arch) const
const std::size_t k_per_block = std::stoi(k_per_block_str); const std::size_t k_per_block = std::stoi(k_per_block_str);
const std::size_t n1_per_block = std::stoi(n1_per_block_str); const std::size_t n1_per_block = std::stoi(n1_per_block_str);
const std::size_t grid_size = GetGridSize(M, O, m_per_block, n1_per_block); const std::size_t grid_size = GetGridSize(M, O, m_per_block, n1_per_block);
params[GEMMSpecialization_idx] = GetGemmSpec(M, N, K, O, m_per_block, n_per_block, k_per_block, n1_per_block); params[GEMMSpecialization_idx] =
GetGemmSpec(M, N, K, O, m_per_block, n_per_block, k_per_block, n1_per_block);
std::string str = std::accumulate( std::string str = std::accumulate(
params.begin() + 1, params.begin() + 1,
......
...@@ -101,18 +101,21 @@ Solution Problem::MakeSolution(std::size_t idx, const std::string& arch) const ...@@ -101,18 +101,21 @@ Solution Problem::MakeSolution(std::size_t idx, const std::string& arch) const
if(ADataType == DataType::Int8 and BDataType == DataType::Int8) if(ADataType == DataType::Int8 and BDataType == DataType::Int8)
{ {
// Change CBlockTransfer ScalarPerVector if Ds contains other types // Change CBlockTransfer ScalarPerVector if Ds contains other types
if(EDataType == DataType::Half or std::any_of( if(EDataType == DataType::Half or std::any_of(DsDataType.begin(),
DsDataType.begin(), DsDataType.end(), [](auto t) { return t == DataType::Half; })) DsDataType.end(),
[](auto t) { return t == DataType::Half; }))
{ {
params[params.size() - 3] = "8"; params[params.size() - 3] = "8";
} }
if(EDataType == DataType::Float or std::any_of( if(EDataType == DataType::Float or std::any_of(DsDataType.begin(),
DsDataType.begin(), DsDataType.end(), [](auto t) { return t == DataType::Float; })) DsDataType.end(),
[](auto t) { return t == DataType::Float; }))
{ {
params[params.size() - 3] = "4"; params[params.size() - 3] = "4";
} }
if(EDataType == DataType::Int32 or std::any_of( if(EDataType == DataType::Int32 or std::any_of(DsDataType.begin(),
DsDataType.begin(), DsDataType.end(), [](auto t) { return t == DataType::Int32; })) DsDataType.end(),
[](auto t) { return t == DataType::Int32; }))
{ {
params[params.size() - 3] = "4"; params[params.size() - 3] = "4";
} }
...@@ -141,7 +144,7 @@ Solution Problem::MakeSolution(std::size_t idx, const std::string& arch) const ...@@ -141,7 +144,7 @@ Solution Problem::MakeSolution(std::size_t idx, const std::string& arch) const
std::string{}, std::string{},
[](const std::string& a, const std::string& b) { return a.empty() ? b : a + ", " + b; }); [](const std::string& a, const std::string& b) { return a.empty() ? b : a + ", " + b; });
str = params.front() + "< " + str + ">"; str = params.front() + "< " + str + ">";
if (params.back().find("v2") != std::string::npos and K % k_per_block != 0) if(params.back().find("v2") != std::string::npos and K % k_per_block != 0)
str = ""; str = "";
return Solution{str, block_size, grid_size}; return Solution{str, block_size, grid_size};
...@@ -159,7 +162,7 @@ std::vector<Solution> Problem::GetSolutions(const std::string& arch) const ...@@ -159,7 +162,7 @@ std::vector<Solution> Problem::GetSolutions(const std::string& arch) const
for(std::size_t i = 0; i < num_instances; ++i) for(std::size_t i = 0; i < num_instances; ++i)
{ {
auto solution = MakeSolution(i, arch); auto solution = MakeSolution(i, arch);
if (solution.template_str != "") if(solution.template_str != "")
solutions.push_back(solution); solutions.push_back(solution);
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment