Commit 786a0faa authored by Jing Zhang's avatar Jing Zhang
Browse files

add permute switch as a template

parent 6a2521ea
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
using ADataType = ck::half_t; using ADataType = ck::half_t;
using BDataType = ck::f8_t; using BDataType = ck::f8_t;
using AccDataType = float; using AccDataType = float;
using CShuffleDataType = float; using CShuffleDataType = ck::half_t;
using CDataType = ck::half_t; using CDataType = ck::half_t;
using ALayout = Row; using ALayout = Row;
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
using ADataType = ck::half_t; using ADataType = ck::half_t;
using BDataType = ck::pk_i4_t; using BDataType = ck::pk_i4_t;
using AccDataType = float; using AccDataType = float;
using CShuffleDataType = float; using CShuffleDataType = ck::half_t;
using CDataType = ck::half_t; using CDataType = ck::half_t;
using ALayout = Row; using ALayout = Row;
...@@ -21,6 +21,8 @@ using CElementOp = PassThrough; ...@@ -21,6 +21,8 @@ using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr bool PermuteB = true;
// clang-format off // clang-format off
using DeviceGemmV2Instance = using DeviceGemmV2Instance =
ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3< ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3<
...@@ -38,7 +40,7 @@ using DeviceGemmV2Instance = ...@@ -38,7 +40,7 @@ using DeviceGemmV2Instance =
S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 32, 32, 1, 2, 32, 32, 1,
1, 1, S<1, 16, 1, 4>, 4, 1, 1, S<1, 16, 1, 4>, 4,
ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1>; ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, CDataType, CDataType, false, PermuteB>;
[[maybe_unused]] static int KPerBlock = 256; [[maybe_unused]] static int KPerBlock = 256;
#else #else
...@@ -52,7 +54,7 @@ using DeviceGemmV2Instance = ...@@ -52,7 +54,7 @@ using DeviceGemmV2Instance =
S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 32, 32, 0, 2, 32, 32, 0,
1, 1, S<1, 16, 1, 8>, 4, 1, 1, S<1, 16, 1, 8>, 4,
ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v2>; ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v2, CDataType, CDataType, false, PermuteB>;
[[maybe_unused]]static int KPerBlock = 128; [[maybe_unused]]static int KPerBlock = 128;
#endif #endif
...@@ -123,7 +125,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ...@@ -123,7 +125,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1}); b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
break; break;
case 1: case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}); a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2}); b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
break; break;
case 2: case 2:
...@@ -136,7 +138,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ...@@ -136,7 +138,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
break; break;
default: default:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}); a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}); b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
} }
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
...@@ -150,8 +152,9 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ...@@ -150,8 +152,9 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize()); DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
// weight permute // weight permute
#if 1 if constexpr(PermuteB)
{
int K1 = KPerBlock; int K1 = KPerBlock;
int K0 = K / KPerBlock; int K0 = K / KPerBlock;
...@@ -166,8 +169,9 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ...@@ -166,8 +169,9 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
} }
} }
} }
}
#else else
{
for(int i = 0; i < N; i++) for(int i = 0; i < N; i++)
{ {
for(int j = 0; j < K; j++) for(int j = 0; j < K; j++)
...@@ -175,7 +179,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ...@@ -175,7 +179,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
b_k_n_permute(i * K + j) = b_k_n(i * K + j); b_k_n_permute(i * K + j) = b_k_n(i * K + j);
} }
} }
#endif }
// vector pk_i4x4 permute // vector pk_i4x4 permute
for(int i = 0; i < N; i++) for(int i = 0; i < N; i++)
......
...@@ -64,7 +64,9 @@ template <typename ALayout, ...@@ -64,7 +64,9 @@ template <typename ALayout,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1, BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
typename ComputeTypeA = CDataType, typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA> typename ComputeTypeB = ComputeTypeA,
bool PermuteA = false,
bool PermuteB = false>
struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout, struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
BLayout, BLayout,
CLayout, CLayout,
...@@ -122,7 +124,9 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout, ...@@ -122,7 +124,9 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
BlkGemmPipeSched, BlkGemmPipeSched,
BlkGemmPipelineVer, BlkGemmPipelineVer,
ComputeTypeA, ComputeTypeA,
ComputeTypeB>; ComputeTypeB,
PermuteA,
PermuteB>;
using Argument = typename GridwiseGemm::Argument; using Argument = typename GridwiseGemm::Argument;
......
...@@ -14,8 +14,6 @@ ...@@ -14,8 +14,6 @@
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#define WEIGHT_PERMUTE
namespace ck { namespace ck {
// Currently we do not have a elegant way to put single lds buffer & double lds buffer pipe in same // Currently we do not have a elegant way to put single lds buffer & double lds buffer pipe in same
...@@ -129,7 +127,9 @@ template <typename ALayout, ...@@ -129,7 +127,9 @@ template <typename ALayout,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v4, BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v4,
typename ComputeTypeA = CDataType, typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA> typename ComputeTypeB = ComputeTypeA,
bool PermuteA = false,
bool PermuteB = false>
struct GridwiseGemm_xdl_cshuffle_v3 struct GridwiseGemm_xdl_cshuffle_v3
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
...@@ -389,7 +389,8 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -389,7 +389,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
} }
else else
{ {
#ifndef WEIGHT_PERMUTE if constexpr(!PermuteB)
{
// not pad N or K // not pad N or K
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_grid_desc_nraw_kraw, b_grid_desc_nraw_kraw,
...@@ -399,9 +400,12 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -399,9 +400,12 @@ struct GridwiseGemm_xdl_cshuffle_v3
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1; return b_grid_desc_bk0_n_bk1;
#else }
else
{
// Weight Tile Permute // Weight Tile Permute
constexpr index_t BK01 = KPerBlock / BK1Value; constexpr index_t BK01 = KPerBlock / BK1Value;
// const index_t BK00 = BK0 / BK01;
const index_t BK0_ = StrideB / BK1Value; const index_t BK0_ = StrideB / BK1Value;
const index_t BK00 = BK0_ / BK01; const index_t BK00 = BK0_ / BK01;
...@@ -417,7 +421,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -417,7 +421,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
return b_grid_desc_bk0_n_bk1_permute; return b_grid_desc_bk0_n_bk1_permute;
#endif }
} }
} }
...@@ -621,12 +625,15 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -621,12 +625,15 @@ struct GridwiseGemm_xdl_cshuffle_v3
} }
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>) else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{ {
#ifndef WEIGHT_PERMUTE if constexpr(!PermuteB)
{
b_k_split_offset = blockIdx.z * karg.KRead / BPackedSize; b_k_split_offset = blockIdx.z * karg.KRead / BPackedSize;
#else }
else
{
const int k0_offset = karg.KRead * karg.N; const int k0_offset = karg.KRead * karg.N;
b_k_split_offset = blockIdx.z * k0_offset / BPackedSize; b_k_split_offset = blockIdx.z * k0_offset / BPackedSize;
#endif }
} }
if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1)) if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment