"git@developer.sourcefind.cn:zhaoyu6/sglang.git" did not exist on "7bc1dae095efb910345602235394dfd74f391e92"
Commit cae981d5 authored by Chao Liu's avatar Chao Liu
Browse files

clean

parent 1a0d27b0
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
......@@ -5,13 +8,13 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_waveletmodel.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_waveletmodel.hpp"
namespace ck {
template <typename GridwiseGemm,
......@@ -156,7 +159,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
__device__ static index_t GetThreadId() { return get_thread_local_1d_id(); }
};
using CShuffleBlockTransferThreadGroup = ThisThreadBlock<TileMathThreadGroupSize>;
// load and math+store Wave pipelines.
// TODO: build pipelines blocks scheduling parallel tasks
using GridwiseGemmLoad = GridwiseGemmLoadWave<TileLoadThreadGroup, NumGemmKPrefetchStage>;
......@@ -220,6 +225,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
sizeof(FloatAB),
c_block_size * sizeof(FloatCShuffle));
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
__host__ __device__ static constexpr bool
CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
......@@ -251,6 +257,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return true;
}
__host__ __device__ static constexpr index_t
CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
{
......@@ -377,6 +384,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
KPerBlock);
// divide block work by [M, N]
const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
......@@ -479,7 +487,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
}
else if(TileMathThreadGroup::IsBelong())
{
// branch early for math wave
constexpr index_t KPack =
math::max(math::lcm(AK1, BK1),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment