Unverified Commit 9eb97340 authored by Mingtao Gu's avatar Mingtao Gu Committed by GitHub
Browse files

Merge pull request #1 from ROCm/i4_update

Extend support for the case KPerBlock <= ScaleBlockK
parents 40054f53 1a324dfb
...@@ -31,6 +31,7 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_v3) ...@@ -31,6 +31,7 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_v3)
add_example_executable(example_gemm_xdl_fp16_fp8_v3 gemm_xdl_fp16_fp8_v3.cpp) add_example_executable(example_gemm_xdl_fp16_fp8_v3 gemm_xdl_fp16_fp8_v3.cpp)
add_example_executable(example_gemm_xdl_fp16_pk_i4_v3 gemm_xdl_fp16_pk_i4_v3.cpp) add_example_executable(example_gemm_xdl_fp16_pk_i4_v3 gemm_xdl_fp16_pk_i4_v3.cpp)
add_example_executable(example_gemm_xdl_fp16_pk_i4_v3_b_scale gemm_xdl_fp16_pk_i4_v3_b_scale.cpp) add_example_executable(example_gemm_xdl_fp16_pk_i4_v3_b_scale gemm_xdl_fp16_pk_i4_v3_b_scale.cpp)
target_compile_options(example_gemm_xdl_fp16_pk_i4_v3_b_scale PRIVATE -mllvm -greedy-reverse-local-assignment=1 -save-temps=$PWD -Wno-gnu-line-marker)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8_v3) add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8_v3)
add_example_executable(example_gemm_xdl_bf16_v3 gemm_xdl_bf16_v3.cpp) add_example_executable(example_gemm_xdl_bf16_v3 gemm_xdl_bf16_v3.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_bf16_v3) add_example_dependencies(example_gemm_xdl example_gemm_xdl_bf16_v3)
......
...@@ -27,7 +27,7 @@ static constexpr bool PermuteB = false; ...@@ -27,7 +27,7 @@ static constexpr bool PermuteB = false;
static constexpr ck::index_t Scale_Block_N = 1; static constexpr ck::index_t Scale_Block_N = 1;
static constexpr ck::index_t Scale_Block_K = 128; static constexpr ck::index_t Scale_Block_K = 128;
static constexpr ck::index_t KPerBlock = 128; static constexpr ck::index_t KPerBlock = 64;
// clang-format off // clang-format off
using DeviceGemmV2Instance = using DeviceGemmV2Instance =
...@@ -35,29 +35,16 @@ using DeviceGemmV2Instance = ...@@ -35,29 +35,16 @@ using DeviceGemmV2Instance =
ALayout, BLayout, CLayout, ALayout, BLayout, CLayout,
ADataType, BDataType, BScaleDataType, CDataType, AccDataType, CShuffleDataType, ADataType, BDataType, BScaleDataType, CDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CElementOp, GemmDefault, AElementOp, BElementOp, CElementOp, GemmDefault,
#if 0
128, Scale_Block_N, Scale_Block_K,
16, 128,
KPerBlock, 8, 32,
16, 16,
1, 4,
S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 8, 8, 0,
S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 32, 32, 0,
1, 1, S<1, 16, 1, 8>, 4,
#else
256, Scale_Block_N, Scale_Block_K, 256, Scale_Block_N, Scale_Block_K,
128, 128, 128, 128,
KPerBlock, 8, 32, KPerBlock, 8, 32,
32, 32, 32, 32,
2, 2, 2, 2,
S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 8, 8, 0, 2, 8, 8, 0,
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 32, 32, 0, 2, 32, 32, 0,
1, 1, S<1, 16, 1, 8>, 4, 1, 1, S<1, 32, 1, 8>, 8,
#endif
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, CDataType, CDataType, false, PermuteB>; ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, CDataType, CDataType, false, PermuteB>;
// clang-format on // clang-format on
......
...@@ -328,7 +328,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra ...@@ -328,7 +328,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
b_scale_thread_copy.Run(b_scale_grid_desc, b_scale_thread_copy.Run(b_scale_grid_desc,
b_scale_grid_buf, b_scale_grid_buf,
b_scale_thread_desc, b_scale_thread_desc,
make_tuple(n0, I0), make_tuple(n0, I0, I0),
b_scale_thread_buf); b_scale_thread_buf);
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
...@@ -455,7 +455,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra ...@@ -455,7 +455,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
b_scale_thread_copy.Run(b_scale_grid_desc, b_scale_thread_copy.Run(b_scale_grid_desc,
b_scale_grid_buf, b_scale_grid_buf,
b_scale_thread_desc, b_scale_thread_desc,
make_tuple(n0, I0), make_tuple(n0, I0, I0),
b_scale_thread_buf); b_scale_thread_buf);
b_scale_thread_copy.MoveSrcSliceWindow( b_scale_thread_copy.MoveSrcSliceWindow(
......
...@@ -713,7 +713,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -713,7 +713,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
__device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{ {
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
if constexpr(ABlockLdsExtraM) if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
{ {
return make_naive_tensor_descriptor( return make_naive_tensor_descriptor(
make_tuple(AK0Number, Number<MPerBlock>{}, AK1Number), make_tuple(AK0Number, Number<MPerBlock>{}, AK1Number),
...@@ -849,7 +849,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -849,7 +849,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
__device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
{ {
// B matrix in LDS memory, dst of blockwise copy // B matrix in LDS memory, dst of blockwise copy
if constexpr(BBlockLdsExtraN) if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
{ {
return make_naive_tensor_descriptor( return make_naive_tensor_descriptor(
make_tuple(BK0Number, Number<NPerBlock>{}, BK1Number), make_tuple(BK0Number, Number<NPerBlock>{}, BK1Number),
...@@ -1303,8 +1303,11 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1303,8 +1303,11 @@ struct GridwiseGemm_xdl_cshuffle_v3
// B Scale grid and buffer // B Scale grid and buffer
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor( const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN), make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN),
math::integer_divide_ceil(problem.K, ScaleBlockK)), math::integer_divide_ceil(problem.K, ScaleBlockK),
make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1)); ScaleBlockK),
make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK),
1,
0));
const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize()); p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize());
...@@ -1435,11 +1438,12 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1435,11 +1438,12 @@ struct GridwiseGemm_xdl_cshuffle_v3
KPerBlock); KPerBlock);
// b scale // b scale
static_assert(KPerBlock<=ScaleBlockK);
const index_t ScaleSliceSizeN = NXdlPerWave; const index_t ScaleSliceSizeN = NXdlPerWave;
const index_t ScaleSliceSizeK = 1; const index_t ScaleSliceSizeK = 1;
constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed( constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<ScaleSliceSizeN>{}, Number<ScaleSliceSizeK>{})); make_tuple(Number<ScaleSliceSizeN>{}, Number<ScaleSliceSizeK>{}, Number<1>{}));
constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
...@@ -1451,17 +1455,18 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1451,17 +1455,18 @@ struct GridwiseGemm_xdl_cshuffle_v3
BScaleType, BScaleType,
decltype(b_scale_grid_desc_bn_ak), decltype(b_scale_grid_desc_bn_ak),
decltype(b_scale_thread_desc), decltype(b_scale_thread_desc),
Sequence<1, ScaleSliceSizeK>, Sequence<1, ScaleSliceSizeK, 1>,
Sequence<0, 1>, Sequence<0, 1, 2>,
1, 1,
1, 1,
1, 1,
false>( false>(
b_scale_grid_desc_bn_ak, b_scale_grid_desc_bn_ak,
make_multi_index(block_n_id * NPerBlock / ScaleBlockN + b_thread_offset, 0)); make_multi_index(block_n_id * NPerBlock / ScaleBlockN + b_thread_offset, 0, 0));
constexpr auto b_scale_thread_slice_copy_step = constexpr auto b_scale_thread_slice_copy_step =
make_tuple(make_multi_index(NWaves * NPerXdl, 0), make_multi_index(-NPerBlock, 1)); make_tuple(make_multi_index(NWaves * NPerXdl, 0, 0),
make_multi_index(-NPerBlock, KPerBlock/ScaleBlockK, KPerBlock%ScaleBlockK));
const index_t num_k_block_per_scale = ScaleBlockK / KPerBlock; const index_t num_k_block_per_scale = ScaleBlockK / KPerBlock;
...@@ -1478,13 +1483,11 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1478,13 +1483,11 @@ struct GridwiseGemm_xdl_cshuffle_v3
b_block_buf, b_block_buf,
b_block_slice_copy_step, b_block_slice_copy_step,
c_thread_buf, c_thread_buf,
b_scale_grid_desc_bn_ak, b_scale_grid_desc_bn_ak,
b_scale_thread_desc, b_scale_thread_desc,
b_scale_thread_copy, b_scale_thread_copy,
b_scale_grid_buf, b_scale_grid_buf,
b_scale_thread_slice_copy_step, b_scale_thread_slice_copy_step,
num_k_block_main_loop, num_k_block_main_loop,
num_k_block_per_scale); num_k_block_per_scale);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment