"docs/git@developer.sourcefind.cn:change/sglang.git" did not exist on "d39899e85c5c29b3aeb2ea36d19f59214de60336"
Commit 6d9425ec authored by Jianfeng yan's avatar Jianfeng yan
Browse files

added block2CTileMap, but results are not correct

parent f4f94f70
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include "device_base.hpp" #include "device_base.hpp"
#include "device_gemm.hpp" #include "device_gemm.hpp"
#include "common_header.hpp" #include "common_header.hpp"
#include "statically_indexed_array.hpp"
#include "tensor_layout.hpp" #include "tensor_layout.hpp"
#include "tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp" #include "tensor_descriptor_helper.hpp"
...@@ -181,6 +182,51 @@ struct DeviceGroupedGemmXdl ...@@ -181,6 +182,51 @@ struct DeviceGroupedGemmXdl
using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1)); using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
template <int GroupCount>
struct Block2CTileMap
{
Block2CTileMap() = default;
template <typename GemmDesc>
// Block2CTileMap(const StaticallyIndexedArray<GemmDesc, GroupCount>& gemm_desc)
Block2CTileMap(const std::vector<GemmDesc>& gemm_desc, const index_t N0) : N0_{N0}
{
for(index_t grp = 0; grp < GroupCount-1; ++grp)
{
assert(gemm_desc[grp].BlockEnd == gemm_desc[grp+1].BlockStart);
}
for(index_t grp = 0; grp < GroupCount; ++grp)
{
block_ptr_[grp] = gemm_desc[grp].BlockStart;
}
block_ptr_[GroupCount] = gemm_desc[GroupCount-1].BlockEnd;
}
template <typename Index>
__host__ __device__ constexpr auto CalculateBottomIndex(Index blockIdx) const
{
index_t block_id = blockIdx[Number<0>{}];
index_t local_block_id;
for(index_t grp = 0; grp < MaxGroupCount; ++grp)
{
if(block_id >= block_ptr_[grp] && block_id < block_ptr_[grp+1])
{
local_block_id = block_id - block_ptr_[grp];
}
}
return make_tuple(local_block_id / N0_, local_block_id % N0_);
// return make_tuple(local_block_id % N0_, local_block_id / N0_);
}
private:
index_t block_ptr_[GroupCount + 1];
index_t N0_;
};
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<
BlockSize, BlockSize,
...@@ -232,13 +278,14 @@ struct DeviceGroupedGemmXdl ...@@ -232,13 +278,14 @@ struct DeviceGroupedGemmXdl
typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_; c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
const ADataType* a_ptr; const ADataType* a_ptr;
const BDataType* b_ptr; const BDataType* b_ptr;
CDataType* c_ptr; CDataType* c_ptr;
ck::index_t BlockStart, BlockEnd; ck::index_t BlockStart, BlockEnd;
// typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
Block2CTileMap<MaxGroupCount> block_2_ctile_map_;
}; };
// Argument // Argument
...@@ -301,15 +348,13 @@ struct DeviceGroupedGemmXdl ...@@ -301,15 +348,13 @@ struct DeviceGroupedGemmXdl
const auto c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ = const auto c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_); GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_);
const auto block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
gemm_desc_kernel_arg_.push_back( gemm_desc_kernel_arg_.push_back(
GemmDescKernelArg{a_grid_desc_k0_m_k1_, GemmDescKernelArg{a_grid_desc_k0_m_k1_,
b_grid_desc_k0_n_k1_, b_grid_desc_k0_n_k1_,
c_grid_desc_m_n_, c_grid_desc_m_n_,
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
block_2_ctile_map_, // block_2_ctile_map_,
static_cast<const ADataType*>(p_a[i]), static_cast<const ADataType*>(p_a[i]),
static_cast<const BDataType*>(p_b[i]), static_cast<const BDataType*>(p_b[i]),
static_cast<CDataType*>(p_c[i]), static_cast<CDataType*>(p_c[i]),
...@@ -317,6 +362,11 @@ struct DeviceGroupedGemmXdl ...@@ -317,6 +362,11 @@ struct DeviceGroupedGemmXdl
BlockEnd}); BlockEnd});
} }
} }
for(index_t i = 0; i < gemm_shapes.size(); i++)
{
gemm_desc_kernel_arg_[i].block_2_ctile_map_ = Block2CTileMap<MaxGroupCount>{gemm_desc_kernel_arg_, gemm_desc_kernel_arg_[i].c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_.GetLength(Number<1>{})};
}
} }
// private: // private:
...@@ -328,6 +378,7 @@ struct DeviceGroupedGemmXdl ...@@ -328,6 +378,7 @@ struct DeviceGroupedGemmXdl
CElementwiseOperation c_element_op_; CElementwiseOperation c_element_op_;
std::vector<GemmDescKernelArg> gemm_desc_kernel_arg_; std::vector<GemmDescKernelArg> gemm_desc_kernel_arg_;
// Block2CTileMap<MaxGroupCount> block_2_ctile_map_;
index_t grid_size_; index_t grid_size_;
}; };
......
...@@ -73,6 +73,7 @@ __global__ void ...@@ -73,6 +73,7 @@ __global__ void
const AElementwiseOperation a_element_op, const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op, const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op) const CElementwiseOperation c_element_op)
{ {
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
...@@ -84,7 +85,7 @@ __global__ void ...@@ -84,7 +85,7 @@ __global__ void
i < group_count) i < group_count)
{ {
auto group_id = i; auto group_id = i;
const index_t block_id_grp = block_id - gemm_desc_[group_id].BlockStart; // const index_t block_id_grp = block_id - gemm_desc_[group_id].BlockStart;
GridwiseGemm::template Run<HasMainK0BlockLoop>( GridwiseGemm::template Run<HasMainK0BlockLoop>(
gemm_desc_[group_id].a_ptr, gemm_desc_[group_id].a_ptr,
...@@ -97,8 +98,8 @@ __global__ void ...@@ -97,8 +98,8 @@ __global__ void
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op, c_element_op,
gemm_desc_[group_id].block_2_ctile_map_, gemm_desc_[group_id].block_2_ctile_map_);
block_id_grp); // block_2_ctile_map);
} }
}); });
#else #else
...@@ -426,8 +427,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -426,8 +427,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
const AElementwiseOperation& a_element_op, const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op, const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op, const CElementwiseOperation& c_element_op,
const Block2CTileMap& block_2_ctile_map, const Block2CTileMap& block_2_ctile_map)
ck::index_t block_id = get_block_1d_id()) // ck::index_t block_id = get_block_1d_id())
{ {
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize()); p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize());
...@@ -440,7 +441,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -440,7 +441,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
// divide block work by [M, N] // divide block work by [M, N]
const auto block_work_idx = const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(block_id)); block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
// HACK: this force m/n_block_data_idx_on_grid into SGPR // HACK: this force m/n_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid = const index_t m_block_data_idx_on_grid =
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment