Commit f0fba871 authored by mtgu0705's avatar mtgu0705
Browse files

Added gemm_fp8xint4_Bpreshuffle files, function not checked yet

parent da8324b6
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp" #include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp"
using F8 = ck::f8_t; using F8 = ck::f8_t;
using I4 = ck::pk_i4_t; using I4 = ck::pk_i4_t;
...@@ -63,7 +63,7 @@ static constexpr ck::index_t KPerBlock = 128; ...@@ -63,7 +63,7 @@ static constexpr ck::index_t KPerBlock = 128;
// clang-format off // clang-format off
using DeviceGemmV2Instance = using DeviceGemmV2Instance =
ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3< ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3_BPreshuffle<
ALayout, BLayout, CLayout, ALayout, BLayout, CLayout,
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CElementOp, GemmDefault, AElementOp, BElementOp, CElementOp, GemmDefault,
...@@ -77,7 +77,7 @@ using DeviceGemmV2Instance = ...@@ -77,7 +77,7 @@ using DeviceGemmV2Instance =
S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 32, 32, 0, 2, 32, 32, 0,
1, 1, S<1, 16, 1, 8>, 4, 1, 1, S<1, 16, 1, 8>, 4,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ADataType, ADataType, PermuteA, PermuteB>; ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, F8, F8, PermuteA, PermuteB>;
// clang-format on // clang-format on
...@@ -174,8 +174,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ...@@ -174,8 +174,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize()); DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
// int NperXdl= 16; int NperXdl = GetPreShuffleParameters;
// preShuffleBuffer(b_k_n.mData.data(), b_k_n_preshuffled.mData.data(), N, K, NperXdl); preShuffleBuffer(b_k_n.mData.data(), b_k_n_preshuffled.mData.data(), N, K, NperXdl);
// weight permute // weight permute
if constexpr(PermuteB) if constexpr(PermuteB)
...@@ -190,7 +190,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ...@@ -190,7 +190,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
{ {
for(int jj = 0; jj < K1; jj++) for(int jj = 0; jj < K1; jj++)
{ {
b_k_n_permute(j * N * K1 + i * K1 + jj) = b_k_n(i * K + (j * K1 + jj)); b_k_n_permute(j * N * K1 + i * K1 + jj) = b_k_n_preshuffled(i * K + (j * K1 + jj));
} }
} }
} }
...@@ -201,7 +201,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ...@@ -201,7 +201,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
{ {
for(int j = 0; j < K; j++) for(int j = 0; j < K; j++)
{ {
b_k_n_permute(i * K + j) = b_k_n(i * K + j); b_k_n_permute(i * K + j) = b_k_n_preshuffled(i * K + j);
} }
} }
} }
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp"
namespace ck {
template <BlockGemmPipelineVersion BlkGemmPipelineVer,
BlockGemmPipelineScheduler BlkGemmPipeSche,
index_t BlockSize,
typename ADataType,
typename BDataType,
typename ComputeDataType,
typename AccDataType,
typename ATileDesc,
typename BTileDesc,
typename AMmaTileDesc,
typename BMmaTileDesc,
index_t ABlockTransferSrcScalarPerVector,
index_t BBlockTransferSrcScalarPerVector,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerXDL,
index_t NPerXDL,
index_t MRepeat,
index_t NRepeat,
index_t KPack>
constexpr auto BlockGemmBPreshufflePipeline_Selector()
{
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
return BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlkGemmPipeSche,
BlockSize,
ADataType,
BDataType,
ComputeDataType,
AccDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>{};
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
{
return BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlkGemmPipeSche,
BlockSize,
ADataType,
BDataType,
ComputeDataType,
AccDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>{};
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
return BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlkGemmPipeSche,
BlockSize,
ADataType,
BDataType,
ComputeDataType,
AccDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>{};
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
{
return BlockwiseGemmXdlops_pipeline_v4<BlkGemmPipeSche,
BlockSize,
ADataType,
BDataType,
ComputeDataType,
AccDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>{};
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v5)
{
return BlockwiseGemmXdlops_pipeline_v5<BlkGemmPipeSche,
BlockSize,
ADataType,
BDataType,
ComputeDataType,
AccDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>{};
}
else
{
std::cerr << "BlockGemmPipeline configuration is not available" << std::endl;
}
}
} // namespace ck
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -114,6 +114,38 @@ struct DeviceGemmV2BScale : public BaseOperator ...@@ -114,6 +114,38 @@ struct DeviceGemmV2BScale : public BaseOperator
virtual ck::index_t GetKPerBlock() = 0; virtual ck::index_t GetKPerBlock() = 0;
}; };
template <typename ALayout,
typename BLayout,
typename CLayout,
typename ADataType,
typename BDataType,
typename CDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
struct DeviceGemmV2BPreshuffle : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
ck::index_t KSplit,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
virtual bool GetPermuteB() = 0;
virtual ck::index_t GetKPerBlock() = 0;
};
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment