Commit 2fe5c02a authored by aska-0096's avatar aska-0096
Browse files

Enable KPerBlock=256

parent 72da1930
...@@ -27,7 +27,7 @@ static constexpr bool PermuteB = true; ...@@ -27,7 +27,7 @@ static constexpr bool PermuteB = true;
static constexpr ck::index_t Scale_Block_N = 1; static constexpr ck::index_t Scale_Block_N = 1;
static constexpr ck::index_t Scale_Block_K = 128; static constexpr ck::index_t Scale_Block_K = 128;
static constexpr ck::index_t KPerBlock = 64; static constexpr ck::index_t KPerBlock = 256;
// clang-format off // clang-format off
using DeviceGemmV2Instance = using DeviceGemmV2Instance =
...@@ -35,27 +35,27 @@ using DeviceGemmV2Instance = ...@@ -35,27 +35,27 @@ using DeviceGemmV2Instance =
ALayout, BLayout, CLayout, ALayout, BLayout, CLayout,
ADataType, BDataType, BScaleDataType, CDataType, AccDataType, CShuffleDataType, ADataType, BDataType, BScaleDataType, CDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CElementOp, GemmDefault, AElementOp, BElementOp, CElementOp, GemmDefault,
// 256, Scale_Block_N, Scale_Block_K,
// 128, 128,
// KPerBlock, 8, 32,
// 32, 32,
// 4, 1,
// S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
// 2, 8, 8, 0,
// S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>,
// 2, 32, 32, 0,
// 1, 1, S<1, 32, 1, 8>, 8,
256, Scale_Block_N, Scale_Block_K, 256, Scale_Block_N, Scale_Block_K,
128, 128, 16, 128,
KPerBlock, 8, 32, KPerBlock, 8, 32,
32, 32, 16, 16,
4, 1, 1, 2,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 8, 8, 0, 2, 8, 8, 0,
S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 32, 32, 0, 2, 32, 32, 0,
1, 1, S<1, 32, 1, 8>, 8, 1, 2, S<1, 16, 1, 16>, 8,
// 128, Scale_Block_N, Scale_Block_K,
// 16, 128,
// KPerBlock, 8, 32,
// 16, 16,
// 1, 4,
// S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>,
// 2, 8, 8, 0,
// S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
// 2, 32, 32, 0,
// 1, 1, S<1, 16, 1, 8>, 4,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, CDataType, CDataType, false, PermuteB>; ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, CDataType, CDataType, false, PermuteB>;
// clang-format on // clang-format on
...@@ -374,9 +374,9 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ...@@ -374,9 +374,9 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
std::cout<<"scale_b1_k_n: "<<std::endl; std::cout<<"scale_b1_k_n: "<<std::endl;
for(int i = 0; i < N; i++) for(int i = 0; i < N; i++)
{ {
for(int j = 0; j < K; j++) for(int j = 0; j < (K + Scale_Block_K - 1) / Scale_Block_K; j++)
{ {
std::cout << ck::type_convert<float>(b1_k_n(i,j)) << ","; std::cout << ck::type_convert<float>(b1_k_n(j,i)) << ",";
} }
std::cout << std::endl; std::cout << std::endl;
} }
......
...@@ -346,6 +346,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra ...@@ -346,6 +346,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
} }
constexpr auto num_scale_k_block = BScaleThreadDesc{}.GetLength(I1); constexpr auto num_scale_k_block = BScaleThreadDesc{}.GetLength(I1);
constexpr auto num_scale_krepeat = KRepeat / num_scale_k_block;
// Local prefill 1 // Local prefill 1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
...@@ -373,13 +374,14 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra ...@@ -373,13 +374,14 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
a_thread_buf); a_thread_buf);
}); });
static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, b_thread_copy_.Run(
make_tuple(n0, I0, I0, Number<k0 * BMmaKStride>{}), b_block_desc_n0_n1_n2_k,
b_block_buf, make_tuple(n0, I0, I0, Number<k0 * BMmaKStride>{}),
b_scale_thread_buf[Number<n0 * num_scale_k_block + k0/num_scale_k_block>{}], b_block_buf,
b_thread_desc_, b_scale_thread_buf[Number<n0 * num_scale_k_block + k0 / num_scale_krepeat>{}],
make_tuple(n0, I0, k0, I0), b_thread_desc_,
b_thread_buf); make_tuple(n0, I0, k0, I0),
b_thread_buf);
}); });
}); });
...@@ -469,7 +471,8 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra ...@@ -469,7 +471,8 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k0 * BMmaKStride>{}), make_tuple(n0, I0, I0, Number<k0 * BMmaKStride>{}),
b_block_buf, b_block_buf,
b_scale_thread_buf[Number<n0 * num_scale_k_block + k0/num_scale_k_block>{}], b_scale_thread_buf[Number<n0 * num_scale_k_block +
k0 / num_scale_krepeat>{}],
b_thread_desc_, b_thread_desc_,
make_tuple(n0, I0, k0, I0), make_tuple(n0, I0, k0, I0),
b_thread_buf); b_thread_buf);
......
...@@ -1424,16 +1424,24 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1424,16 +1424,24 @@ struct GridwiseGemm_xdl_cshuffle_v3
// b scale // b scale
// static_assert(KPerBlock <= ScaleBlockK); // static_assert(KPerBlock <= ScaleBlockK);
const index_t ScaleSliceSizeN = NXdlPerWave; static constexpr auto mfma = MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl>{};
const index_t ScaleSliceSizeK = (KPerBlock + ScaleBlockK - 1) / ScaleBlockK; static constexpr auto KPerXdlops = mfma.GetKPerXdlops();
static constexpr auto K1PerXdlops = mfma.GetK1PerXdlops();
static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops;
static constexpr auto KPerThread = KPerBlock / K0PerXdlops;
static constexpr auto ScaleSliceSizeN = NXdlPerWave;
static constexpr auto ScaleSliceSizeK = (KPerThread + ScaleBlockK - 1) / ScaleBlockK;
static constexpr auto KBlockScaleSliceSizeK = (KPerBlock + ScaleBlockK - 1) / ScaleBlockK;
constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed( constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<ScaleSliceSizeN>{}, Number<ScaleSliceSizeK>{})); make_tuple(Number<ScaleSliceSizeN>{}, Number<ScaleSliceSizeK>{}));
constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
auto b_thread_offset = auto b_thread_offset_n =
get_thread_local_1d_id() % NPerXdl + (get_thread_local_1d_id() / 64) % NWaves * NPerXdl; get_thread_local_1d_id() % NPerXdl + (get_thread_local_1d_id() / 64) % NWaves * NPerXdl;
auto b_thread_offset_k = (get_thread_local_1d_id() % 64) / NPerXdl * KPerThread;
auto b_scale_thread_copy = auto b_scale_thread_copy =
ThreadwiseTensorSliceTransfer_v2<BScaleType, ThreadwiseTensorSliceTransfer_v2<BScaleType,
...@@ -1443,16 +1451,17 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1443,16 +1451,17 @@ struct GridwiseGemm_xdl_cshuffle_v3
Sequence<1, ScaleSliceSizeK>, Sequence<1, ScaleSliceSizeK>,
Sequence<0, 1>, Sequence<0, 1>,
1, 1,
1, ScaleSliceSizeK,
1, 1,
false>( false>(
b_scale_grid_desc_bn_ak, b_scale_grid_desc_bn_ak,
make_multi_index(block_n_id * NPerBlock / ScaleBlockN + b_thread_offset, 0)); make_multi_index(block_n_id * NPerBlock / ScaleBlockN + b_thread_offset_n,
b_thread_offset_k / ScaleBlockK));
constexpr auto b_scale_thread_slice_copy_step = constexpr auto b_scale_thread_slice_copy_step =
make_tuple(make_multi_index(NWaves * NPerXdl, 0), make_tuple(make_multi_index(NWaves * NPerXdl, 0),
make_multi_index(-NPerBlock, 0), make_multi_index(-NPerBlock, 0),
make_multi_index(-NPerBlock, ScaleSliceSizeK)); make_multi_index(-NPerBlock, KBlockScaleSliceSizeK));
const index_t num_k_block_per_scale = (ScaleBlockK + KPerBlock - 1) / KPerBlock; const index_t num_k_block_per_scale = (ScaleBlockK + KPerBlock - 1) / KPerBlock;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment