Commit ed305f6b authored by Umang Yadav's avatar Umang Yadav
Browse files

formatting

parent 9f4e3544
...@@ -822,7 +822,7 @@ struct BlockToCTileMap_GemmStreamK ...@@ -822,7 +822,7 @@ struct BlockToCTileMap_GemmStreamK
dp_num_blocks = num_tiles; // all tile to be dp block dp_num_blocks = num_tiles; // all tile to be dp block
dp_start_block_idx = 0; dp_start_block_idx = 0;
sk_total_iters = 0; // clear this tiles sk_total_iters = 0; // clear this tiles
} }
else else
{ {
......
...@@ -72,7 +72,7 @@ struct MagicDivision ...@@ -72,7 +72,7 @@ struct MagicDivision
// integral_constant<uint32_t, .> // integral_constant<uint32_t, .>
template <uint32_t Divisor> template <uint32_t Divisor>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
CalculateMagicNumbers(integral_constant<uint32_t, Divisor>) CalculateMagicNumbers(integral_constant<uint32_t, Divisor>)
{ {
constexpr auto tmp = CalculateMagicNumbers(uint32_t{Divisor}); constexpr auto tmp = CalculateMagicNumbers(uint32_t{Divisor});
...@@ -85,7 +85,7 @@ struct MagicDivision ...@@ -85,7 +85,7 @@ struct MagicDivision
template <uint32_t Divisor> template <uint32_t Divisor>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
CalculateMagicMultiplier(integral_constant<uint32_t, Divisor>) CalculateMagicMultiplier(integral_constant<uint32_t, Divisor>)
{ {
constexpr uint32_t multiplier = CalculateMagicMultiplier(uint32_t{Divisor}); constexpr uint32_t multiplier = CalculateMagicMultiplier(uint32_t{Divisor});
...@@ -94,7 +94,7 @@ struct MagicDivision ...@@ -94,7 +94,7 @@ struct MagicDivision
template <uint32_t Divisor> template <uint32_t Divisor>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
CalculateMagicShift(integral_constant<uint32_t, Divisor>) CalculateMagicShift(integral_constant<uint32_t, Divisor>)
{ {
constexpr uint32_t shift = CalculateMagicShift(uint32_t{Divisor}); constexpr uint32_t shift = CalculateMagicShift(uint32_t{Divisor});
...@@ -104,21 +104,21 @@ struct MagicDivision ...@@ -104,21 +104,21 @@ struct MagicDivision
// integral_constant<int32_t, .> // integral_constant<int32_t, .>
template <int32_t Divisor> template <int32_t Divisor>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
CalculateMagicNumbers(integral_constant<int32_t, Divisor>) CalculateMagicNumbers(integral_constant<int32_t, Divisor>)
{ {
return CalculateMagicNumbers(integral_constant<uint32_t, Divisor>{}); return CalculateMagicNumbers(integral_constant<uint32_t, Divisor>{});
} }
template <int32_t Divisor> template <int32_t Divisor>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
CalculateMagicMultiplier(integral_constant<int32_t, Divisor>) CalculateMagicMultiplier(integral_constant<int32_t, Divisor>)
{ {
return CalculateMagicMultiplier(integral_constant<uint32_t, Divisor>{}); return CalculateMagicMultiplier(integral_constant<uint32_t, Divisor>{});
} }
template <int32_t Divisor> template <int32_t Divisor>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
CalculateMagicShift(integral_constant<int32_t, Divisor>) CalculateMagicShift(integral_constant<int32_t, Divisor>)
{ {
return CalculateMagicShift(integral_constant<uint32_t, Divisor>{}); return CalculateMagicShift(integral_constant<uint32_t, Divisor>{});
} }
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -17,23 +17,23 @@ namespace device_batched_gemm_softmax_gemm { ...@@ -17,23 +17,23 @@ namespace device_batched_gemm_softmax_gemm {
struct Problem struct Problem
{ {
std::size_t M = 0; std::size_t M = 0;
std::size_t N = 0; std::size_t N = 0;
std::size_t K = 0; std::size_t K = 0;
std::size_t O = 0; std::size_t O = 0;
bool TransA = false; bool TransA = false;
bool TransB = false; bool TransB = false;
bool TransB1 = false; bool TransB1 = false;
bool TransC = false; bool TransC = false;
DataType ADataType = DataType::Half; DataType ADataType = DataType::Half;
DataType BDataType = DataType::Half; DataType BDataType = DataType::Half;
DataType B1DataType = DataType::Half; DataType B1DataType = DataType::Half;
DataType CDataType = DataType::Half; DataType CDataType = DataType::Half;
std::string AElementOp = "ck::tensor_operation::element_wise::PassThrough"; std::string AElementOp = "ck::tensor_operation::element_wise::PassThrough";
std::string BElementOp = "ck::tensor_operation::element_wise::PassThrough"; std::string BElementOp = "ck::tensor_operation::element_wise::PassThrough";
std::string B1ElementOp = "ck::tensor_operation::element_wise::PassThrough"; std::string B1ElementOp = "ck::tensor_operation::element_wise::PassThrough";
std::string CElementOp = "ck::tensor_operation::element_wise::PassThrough"; std::string CElementOp = "ck::tensor_operation::element_wise::PassThrough";
std::string AccElementOp = "ck::tensor_operation::element_wise::Scale"; std::string AccElementOp = "ck::tensor_operation::element_wise::Scale";
std::string GetIncludeHeader() const; std::string GetIncludeHeader() const;
...@@ -44,64 +44,65 @@ struct Problem ...@@ -44,64 +44,65 @@ struct Problem
Solution MakeSolution(std::size_t idx, const std::string& arch) const; Solution MakeSolution(std::size_t idx, const std::string& arch) const;
static const std::size_t DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle_idx = 0; static const std::size_t DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle_idx = 0;
static const std::size_t ALayout_idx = 1; static const std::size_t ALayout_idx = 1;
static const std::size_t B0Layout_idx = 2; static const std::size_t B0Layout_idx = 2;
static const std::size_t B1Layout_idx = 3; static const std::size_t B1Layout_idx = 3;
static const std::size_t CLayout_idx = 4; static const std::size_t CLayout_idx = 4;
static const std::size_t ADataType_idx = 5; static const std::size_t ADataType_idx = 5;
static const std::size_t B0DataType_idx = 6; static const std::size_t B0DataType_idx = 6;
static const std::size_t B1DataType_idx = 7; static const std::size_t B1DataType_idx = 7;
static const std::size_t CDataType_idx = 8; static const std::size_t CDataType_idx = 8;
static const std::size_t AccDataType_idx = 9; static const std::size_t AccDataType_idx = 9;
static const std::size_t CShuffleDataType_idx = 10; static const std::size_t CShuffleDataType_idx = 10;
static const std::size_t AElementwiseOperation_idx = 11; static const std::size_t AElementwiseOperation_idx = 11;
static const std::size_t B0ElementwiseOperation_idx = 12; static const std::size_t B0ElementwiseOperation_idx = 12;
static const std::size_t Acc0ElementwiseOperation_idx = 13; static const std::size_t Acc0ElementwiseOperation_idx = 13;
static const std::size_t B1ElementwiseOperation_idx = 14; static const std::size_t B1ElementwiseOperation_idx = 14;
static const std::size_t CElementwiseOperation_idx = 15; static const std::size_t CElementwiseOperation_idx = 15;
static const std::size_t GEMMSpecialization_idx = 16; static const std::size_t GEMMSpecialization_idx = 16;
static const std::size_t NumGemmKPrefetchStage_idx = 17; static const std::size_t NumGemmKPrefetchStage_idx = 17;
static const std::size_t BlockSize_idx = 18; static const std::size_t BlockSize_idx = 18;
static const std::size_t Gemm01MPerBlock_idx = 19; static const std::size_t Gemm01MPerBlock_idx = 19;
static const std::size_t Gemm0NPerBlock_idx = 20; static const std::size_t Gemm0NPerBlock_idx = 20;
static const std::size_t Gemm0KPerBlock_idx = 21; static const std::size_t Gemm0KPerBlock_idx = 21;
static const std::size_t Gemm1NPerBlock_idx = 22; static const std::size_t Gemm1NPerBlock_idx = 22;
static const std::size_t Gemm1KPerBlock_idx = 23; static const std::size_t Gemm1KPerBlock_idx = 23;
static const std::size_t AK1_idx = 24; static const std::size_t AK1_idx = 24;
static const std::size_t BK1_idx = 25; static const std::size_t BK1_idx = 25;
static const std::size_t B1K1_idx = 26; static const std::size_t B1K1_idx = 26;
static const std::size_t MPerXDL_idx = 27; static const std::size_t MPerXDL_idx = 27;
static const std::size_t NPerXDL_idx = 28; static const std::size_t NPerXDL_idx = 28;
static const std::size_t Gemm0MXdlPerWave_idx = 29; static const std::size_t Gemm0MXdlPerWave_idx = 29;
static const std::size_t Gemm0NXdlPerWave_idx = 30; static const std::size_t Gemm0NXdlPerWave_idx = 30;
static const std::size_t Gemm1NXdlPerWave_idx = 31; static const std::size_t Gemm1NXdlPerWave_idx = 31;
static const std::size_t ABlockTransferThreadClusterLengths_K0_M_K1_idx = 32; static const std::size_t ABlockTransferThreadClusterLengths_K0_M_K1_idx = 32;
static const std::size_t ABlockTransferThreadClusterArrangeOrder_idx = 33; static const std::size_t ABlockTransferThreadClusterArrangeOrder_idx = 33;
static const std::size_t ABlockTransferSrcAccessOrder_idx = 34; static const std::size_t ABlockTransferSrcAccessOrder_idx = 34;
static const std::size_t ABlockTransferSrcVectorDim_idx = 35; static const std::size_t ABlockTransferSrcVectorDim_idx = 35;
static const std::size_t ABlockTransferSrcScalarPerVector_idx = 36; static const std::size_t ABlockTransferSrcScalarPerVector_idx = 36;
static const std::size_t ABlockTransferDstScalarPerVector_K1_idx = 37; static const std::size_t ABlockTransferDstScalarPerVector_K1_idx = 37;
static const std::size_t ABlockLdsAddExtraM_idx = 38; static const std::size_t ABlockLdsAddExtraM_idx = 38;
static const std::size_t B0BlockTransferThreadClusterLengths_K0_N_K1_idx = 39; static const std::size_t B0BlockTransferThreadClusterLengths_K0_N_K1_idx = 39;
static const std::size_t B0BlockTransferThreadClusterArrangeOrder_idx = 40; static const std::size_t B0BlockTransferThreadClusterArrangeOrder_idx = 40;
static const std::size_t B0BlockTransferSrcAccessOrder_idx = 41; static const std::size_t B0BlockTransferSrcAccessOrder_idx = 41;
static const std::size_t B0BlockTransferSrcVectorDim_idx = 42; static const std::size_t B0BlockTransferSrcVectorDim_idx = 42;
static const std::size_t B0BlockTransferSrcScalarPerVector_idx = 43; static const std::size_t B0BlockTransferSrcScalarPerVector_idx = 43;
static const std::size_t B0BlockTransferDstScalarPerVector_K1_idx = 44; static const std::size_t B0BlockTransferDstScalarPerVector_K1_idx = 44;
static const std::size_t B0BlockLdsAddExtraN_idx = 45; static const std::size_t B0BlockLdsAddExtraN_idx = 45;
static const std::size_t B1BlockTransferThreadClusterLengths_K0_N_K1_idx = 46; static const std::size_t B1BlockTransferThreadClusterLengths_K0_N_K1_idx = 46;
static const std::size_t B1BlockTransferThreadClusterArrangeOrder_idx = 47; static const std::size_t B1BlockTransferThreadClusterArrangeOrder_idx = 47;
static const std::size_t B1BlockTransferSrcAccessOrder_idx = 48; static const std::size_t B1BlockTransferSrcAccessOrder_idx = 48;
static const std::size_t B1BlockTransferSrcVectorDim_idx = 49; static const std::size_t B1BlockTransferSrcVectorDim_idx = 49;
static const std::size_t B1BlockTransferSrcScalarPerVector_idx = 50; static const std::size_t B1BlockTransferSrcScalarPerVector_idx = 50;
static const std::size_t B1BlockTransferDstScalarPerVector_K1_idx = 51; static const std::size_t B1BlockTransferDstScalarPerVector_K1_idx = 51;
static const std::size_t B1BlockLdsAddExtraN_idx = 52; static const std::size_t B1BlockLdsAddExtraN_idx = 52;
static const std::size_t CShuffleMXdlPerWavePerShuffle_idx = 53; static const std::size_t CShuffleMXdlPerWavePerShuffle_idx = 53;
static const std::size_t CShuffleNXdlPerWavePerShuffle_idx = 54; static const std::size_t CShuffleNXdlPerWavePerShuffle_idx = 54;
static const std::size_t CBlockTransferClusterLengths_MBlock_MWaveMPerXdl_NBlock_NWaveNPerXdl_idx = 55; static const std::size_t
static const std::size_t CBlockTransferScalarPerVector_NWaveNPerXdl_idx = 56; CBlockTransferClusterLengths_MBlock_MWaveMPerXdl_NBlock_NWaveNPerXdl_idx = 55;
static const std::size_t MaskOutUpperTriangle_idx = 57; static const std::size_t CBlockTransferScalarPerVector_NWaveNPerXdl_idx = 56;
static const std::size_t MaskOutUpperTriangle_idx = 57;
}; };
} // namespace device_batched_gemm_softmax_gemm } // namespace device_batched_gemm_softmax_gemm
......
...@@ -64,23 +64,24 @@ Solution Problem::MakeSolution(std::size_t idx, const std::string& arch) const ...@@ -64,23 +64,24 @@ Solution Problem::MakeSolution(std::size_t idx, const std::string& arch) const
std::vector<std::string> params(std::istream_iterator<std::string>{iss}, std::vector<std::string> params(std::istream_iterator<std::string>{iss},
std::istream_iterator<std::string>()); std::istream_iterator<std::string>());
params[AElementwiseOperation_idx] = AElementOp; params[AElementwiseOperation_idx] = AElementOp;
params[B0ElementwiseOperation_idx] = BElementOp; params[B0ElementwiseOperation_idx] = BElementOp;
params[B1ElementwiseOperation_idx] = BElementOp; params[B1ElementwiseOperation_idx] = BElementOp;
params[CElementwiseOperation_idx] = CElementOp; params[CElementwiseOperation_idx] = CElementOp;
params[Acc0ElementwiseOperation_idx] = AccElementOp; params[Acc0ElementwiseOperation_idx] = AccElementOp;
auto block_size_str = params[BlockSize_idx]; auto block_size_str = params[BlockSize_idx];
auto m_per_block_str = params[Gemm01MPerBlock_idx]; auto m_per_block_str = params[Gemm01MPerBlock_idx];
auto n_per_block_str = params[Gemm0NPerBlock_idx]; auto n_per_block_str = params[Gemm0NPerBlock_idx];
auto k_per_block_str = params[Gemm0KPerBlock_idx]; auto k_per_block_str = params[Gemm0KPerBlock_idx];
auto n1_per_block_str = params[Gemm1NPerBlock_idx]; auto n1_per_block_str = params[Gemm1NPerBlock_idx];
const std::size_t block_size = std::stoi(block_size_str); const std::size_t block_size = std::stoi(block_size_str);
const std::size_t m_per_block = std::stoi(m_per_block_str); const std::size_t m_per_block = std::stoi(m_per_block_str);
const std::size_t n_per_block = std::stoi(n_per_block_str); const std::size_t n_per_block = std::stoi(n_per_block_str);
const std::size_t k_per_block = std::stoi(k_per_block_str); const std::size_t k_per_block = std::stoi(k_per_block_str);
const std::size_t n1_per_block = std::stoi(n1_per_block_str); const std::size_t n1_per_block = std::stoi(n1_per_block_str);
const std::size_t grid_size = GetGridSize(M, O, m_per_block, n1_per_block); const std::size_t grid_size = GetGridSize(M, O, m_per_block, n1_per_block);
params[GEMMSpecialization_idx] = GetGemmSpec(M, N, K, O, m_per_block, n_per_block, k_per_block, n1_per_block); params[GEMMSpecialization_idx] =
GetGemmSpec(M, N, K, O, m_per_block, n_per_block, k_per_block, n1_per_block);
std::string str = std::accumulate( std::string str = std::accumulate(
params.begin() + 1, params.begin() + 1,
......
...@@ -101,18 +101,21 @@ Solution Problem::MakeSolution(std::size_t idx, const std::string& arch) const ...@@ -101,18 +101,21 @@ Solution Problem::MakeSolution(std::size_t idx, const std::string& arch) const
if(ADataType == DataType::Int8 and BDataType == DataType::Int8) if(ADataType == DataType::Int8 and BDataType == DataType::Int8)
{ {
// Change CBlockTransfer ScalarPerVector if Ds contains other types // Change CBlockTransfer ScalarPerVector if Ds contains other types
if(EDataType == DataType::Half or std::any_of( if(EDataType == DataType::Half or std::any_of(DsDataType.begin(),
DsDataType.begin(), DsDataType.end(), [](auto t) { return t == DataType::Half; })) DsDataType.end(),
[](auto t) { return t == DataType::Half; }))
{ {
params[params.size() - 3] = "8"; params[params.size() - 3] = "8";
} }
if(EDataType == DataType::Float or std::any_of( if(EDataType == DataType::Float or std::any_of(DsDataType.begin(),
DsDataType.begin(), DsDataType.end(), [](auto t) { return t == DataType::Float; })) DsDataType.end(),
[](auto t) { return t == DataType::Float; }))
{ {
params[params.size() - 3] = "4"; params[params.size() - 3] = "4";
} }
if(EDataType == DataType::Int32 or std::any_of( if(EDataType == DataType::Int32 or std::any_of(DsDataType.begin(),
DsDataType.begin(), DsDataType.end(), [](auto t) { return t == DataType::Int32; })) DsDataType.end(),
[](auto t) { return t == DataType::Int32; }))
{ {
params[params.size() - 3] = "4"; params[params.size() - 3] = "4";
} }
...@@ -134,14 +137,14 @@ Solution Problem::MakeSolution(std::size_t idx, const std::string& arch) const ...@@ -134,14 +137,14 @@ Solution Problem::MakeSolution(std::size_t idx, const std::string& arch) const
const std::size_t k_per_block = std::stoi(k_per_block_str); const std::size_t k_per_block = std::stoi(k_per_block_str);
const std::size_t grid_size = GetGridSize(M, N, m_per_block, n_per_block); const std::size_t grid_size = GetGridSize(M, N, m_per_block, n_per_block);
params[gemm_spec_idx] = GetGemmSpec(M, N, K, m_per_block, n_per_block, k_per_block); params[gemm_spec_idx] = GetGemmSpec(M, N, K, m_per_block, n_per_block, k_per_block);
std::string str = std::accumulate( std::string str = std::accumulate(
params.begin() + 1, params.begin() + 1,
params.end(), params.end(),
std::string{}, std::string{},
[](const std::string& a, const std::string& b) { return a.empty() ? b : a + ", " + b; }); [](const std::string& a, const std::string& b) { return a.empty() ? b : a + ", " + b; });
str = params.front() + "< " + str + ">"; str = params.front() + "< " + str + ">";
if (params.back().find("v2") != std::string::npos and K % k_per_block != 0) if(params.back().find("v2") != std::string::npos and K % k_per_block != 0)
str = ""; str = "";
return Solution{str, block_size, grid_size}; return Solution{str, block_size, grid_size};
...@@ -159,7 +162,7 @@ std::vector<Solution> Problem::GetSolutions(const std::string& arch) const ...@@ -159,7 +162,7 @@ std::vector<Solution> Problem::GetSolutions(const std::string& arch) const
for(std::size_t i = 0; i < num_instances; ++i) for(std::size_t i = 0; i < num_instances; ++i)
{ {
auto solution = MakeSolution(i, arch); auto solution = MakeSolution(i, arch);
if (solution.template_str != "") if(solution.template_str != "")
solutions.push_back(solution); solutions.push_back(solution);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment