"sgl-router/git@developer.sourcefind.cn:change/sglang.git" did not exist on "5d62b56f7e9b79e1fb5d00d50512da7e2d71d481"
Commit d79d1a38 authored by Adam Osewski's avatar Adam Osewski
Browse files

Transpose A/B register tile if needed for comp v3 pipeline.

parent 69b6d2ab
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -245,11 +245,24 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem> ...@@ -245,11 +245,24 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
"A/B Dram block window should have the same data type as appropriate " "A/B Dram block window should have the same data type as appropriate "
"([A|B]DataType) defined in Problem definition!"); "([A|B]DataType) defined in Problem definition!");
static_assert(MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && constexpr bool is_a_col_major =
NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}], constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
"A/B block window appropriate sizes must be equal to MPerBlock/NPerblock"
" or KPerBlock!"); static_assert((is_a_col_major ?
(MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]
KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}]) &&
:
(MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}]
KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}]),
"A block window has incorrect lengths for defined ALayout!");
static_assert((is_b_row_major ?
(KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}]
NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]) &&
:
(NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]
KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}]),
"B block window has incorrect lengths for defined BLayout!");
// ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------
// Definitions of all needed tiles // Definitions of all needed tiles
...@@ -284,23 +297,52 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem> ...@@ -284,23 +297,52 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
ABlockTile a_block_tile; ABlockTile a_block_tile;
BBlockTile b_block_tile; BBlockTile b_block_tile;
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
constexpr ADramTileWindowStep a_dram_tile_window_step =
is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
constexpr BDramTileWindowStep b_dram_tile_window_step =
is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
// ----------------------------------------------------------------------------------------- // -----------------------------------------------------------------------------------------
// Gemm pipeline start // Gemm pipeline start
// prefetch // prefetch
// global read 0 // global read 0
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window); Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window); Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
// initialize C // initialize C
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
// LDS write 0 // LDS write 0
if constexpr(is_a_col_major)
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_block_tile);
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
}
else
{
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
}
if constexpr(is_b_row_major)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_block_tile);
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
}
else
{
Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func); Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
}
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window); Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window); Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
block_sync_lds(); block_sync_lds();
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
...@@ -315,11 +357,31 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem> ...@@ -315,11 +357,31 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
{ {
block_sync_lds(); block_sync_lds();
if constexpr(is_a_col_major)
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_block_tile);
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
}
else
{
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
}
if constexpr(is_b_row_major)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_block_tile);
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
}
else
{
Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func); Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
}
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window); Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window); Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment