"...csrc/git@developer.sourcefind.cn:change/sglang.git" did not exist on "4fc09e0df0f018d74bbc99e16e11e9530d198b1e"
Commit ad2f82ac authored by rocking

Sync code, prepare to test on MI200

parent 1d7290fb
@@ -284,7 +284,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
}
template <typename LayOut>
-static auto MakeGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t Stride)
+static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t Stride)
{
const auto grid_desc_mraw_nraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, LayOut>::value)
@@ -308,11 +308,19 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
[&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
-return DeviceOp::MakeGridDescriptor_M_N<DLayout>(MRaws[i], NRaws[i], DsStride[i]);
+return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(MRaws[i], NRaws[i], DsStride[i]);
},
Number<NumDTensor>{});
}
+static auto MakeMeanVarCountGridDescriptor_M_NBlock(index_t M, index_t NBlock)
+{
+const auto grid_desc_m_n = make_naive_tensor_descriptor_packed(make_tuple(M, NBlock));
+// TODO - padding according to MNperBlock of Gemm and Layernorm
+return grid_desc_m_n;
+}
static auto MakeDescriptor_M(index_t MRaw)
{
const auto grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw));
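
Side note on the descriptor added above: make_naive_tensor_descriptor_packed builds a dense, row-major view, so a minimal stand-in for what the M x NBlock descriptor computes (a sketch, not CK's actual descriptor machinery) would be:

// Hypothetical stand-in for a packed (dense, row-major) 2-D descriptor of
// shape M x NBlock: the element space size is M * NBlock, which is the value
// the GetElementSpaceSize() debug printf added later in this commit reports.
struct PackedDesc_M_NBlock
{
    int M, NBlock;
    int GetElementSpaceSize() const { return M * NBlock; }
    int CalculateOffset(int m, int nb) const { return m * NBlock + nb; } // row-major
};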
@@ -366,9 +374,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1));
using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1));
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
-using MeanVarCountGridDesc_M_N = decltype(MakeGridDescriptor_M_N<ELayout>(1, 1, 1));
+using MeanVarCountGridDesc_M_N = decltype(MakeMeanVarCountGridDescriptor_M_NBlock(1, 1));
using GammaBetaGridDesc_N = decltype(MakeDescriptor_N(1));
-using EHGridDesc_M_N = decltype(MakeGridDescriptor_M_N<HLayout>(1, 1, 1));
+using EHGridDesc_M_N = decltype(MakeEGridDescriptor_M_N<HLayout>(1, 1, 1));
using GridwiseGemmWelford = GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle<
ADataType, // TODO: distinguish A/B datatype
@@ -479,11 +487,11 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(MRaw, KRaw, StrideA)},
b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(KRaw, NRaw, StrideB)},
ds_grid_desc_m_n_{},
-e_grid_desc_m_n_{DeviceOp::MakeGridDescriptor_M_N<ELayout>(MRaw, NRaw, StrideH)},
+e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(MRaw, NRaw, StrideH)},
mean_var_count_grid_desc_m_n_{},
gamma_grid_desc_n_{DeviceOp::MakeDescriptor_N(NRaw)},
beta_grid_desc_n_{DeviceOp::MakeDescriptor_N(NRaw)},
-h_grid_desc_m_n_{DeviceOp::MakeGridDescriptor_M_N<HLayout>(MRaw, NRaw, StrideH)},
+h_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<HLayout>(MRaw, NRaw, StrideH)},
a_grid_desc_ak0_m_ak1_{
GridwiseGemmWelford::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
b_grid_desc_bk0_n_bk1_{
@@ -497,7 +505,10 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
epsilon_{epsilon}
{
mean_var_count_grid_desc_m_n_ =
-DeviceOp::MakeGridDescriptor_M_N<ELayout>(MRaw, gemm_nblock_, gemm_nblock_);
+DeviceOp::MakeMeanVarCountGridDescriptor_M_NBlock(MRaw, gemm_nblock_);
+int s = mean_var_count_grid_desc_m_n_.GetElementSpaceSize();
+printf("mean_var_count_grid_desc_m_n.GetElementSpaceSize() = %d\n", s);
hip_check_error(hipMalloc(&p_e_grid_, sizeof(EDataType) * MRaw * NRaw));
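
The hipMalloc above stages the GEMM output E (MRaw * NRaw unpadded elements) in global memory so the second, layernorm kernel can read it back. A minimal sketch of the same allocation pattern, with a hypothetical helper name and assuming the matching hipFree lives in the destructor (not shown in this excerpt):

#include <hip/hip_runtime.h>
#include <stdexcept>

// Allocate the E staging buffer: one element per (m, n) of the GEMM output.
template <typename EDataType>
EDataType* alloc_e_staging(int MRaw, int NRaw)
{
    EDataType* p = nullptr;
    if(hipMalloc(reinterpret_cast<void**>(&p), sizeof(EDataType) * MRaw * NRaw) != hipSuccess)
        throw std::runtime_error("hipMalloc for E staging buffer failed");
    return p; // caller frees with hipFree
}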
@@ -518,7 +529,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
// D desc
ds_grid_desc_m_n_(i) =
-DeviceOp::MakeGridDescriptor_M_N<DLayout>(MRaw, NRaw, StrideDs[i]);
+DeviceOp::MakeEGridDescriptor_M_N<DLayout>(MRaw, NRaw, StrideDs[i]);
});
// populate desc for Ds/E/F/G
@@ -526,7 +537,6 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
b_grid_desc_n_k_,
ds_grid_desc_m_n_,
e_grid_desc_m_n_,
-mean_var_count_grid_desc_m_n_,
block_2_etile_map_))
{
ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
@@ -612,7 +622,6 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
arg.b_grid_desc_n_k_,
arg.ds_grid_desc_m_n_,
arg.e_grid_desc_m_n_,
-arg.mean_var_count_grid_desc_m_n_,
arg.block_2_etile_map_))
{
throw std::runtime_error("wrong! GridwiseGemmWelford has invalid setting");
@@ -694,7 +703,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
LayernormThreadClusterSize_M_N::At(I1)) /
LayernormThreadClusterSize_M_N::At(I1);
-index_t numXBlockTileIteration_N =
+index_t numEBlockTileIteration_N =
math::integer_least_multiple(N, LayernormBlockTileSize_M_N::At(I1)) /
LayernormBlockTileSize_M_N::At(I1);
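
Both iteration counts in this hunk are ceiling divisions spelled with math::integer_least_multiple (round N up to a multiple of the tile, then divide). A standalone sketch of the arithmetic, using a hypothetical helper name:

// ceil_div(n, tile) == math::integer_least_multiple(n, tile) / tile
constexpr int ceil_div(int n, int tile) { return (n + tile - 1) / tile; }
// e.g. N = 1000, tile = 256 -> 4 iterations, covering 1024 padded columns
static_assert(ceil_div(1000, 256) == 4, "");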
@@ -717,7 +726,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
arg.beta_grid_desc_n_,
arg.gemm_nblock_,
numMeanVarCountBlockTileIteration_N,
-numXBlockTileIteration_N,
+numEBlockTileIteration_N,
arg.epsilon_);
return avg_time;
......
@@ -269,13 +269,11 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template <typename Block2ETileMap>
-__host__ __device__ static constexpr bool
-CheckValidity(const AGridDesc_M_K& a_grid_desc_m_k,
-const BGridDesc_N_K& b_grid_desc_n_k,
-const DsGridDesc_M_N& ds_grid_desc_m_n,
-const EGridDesc_M_N& e_grid_desc_m_n,
-const MeanVarCountGridDesc_M_N& mean_var_count_grid_desc_m_n,
-const Block2ETileMap& block_2_etile_map)
+__host__ __device__ static constexpr bool CheckValidity(const AGridDesc_M_K& a_grid_desc_m_k,
+const BGridDesc_N_K& b_grid_desc_n_k,
+const DsGridDesc_M_N& ds_grid_desc_m_n,
+const EGridDesc_M_N& e_grid_desc_m_n,
+const Block2ETileMap& block_2_etile_map)
{
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
@@ -286,9 +284,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
const auto K = a_grid_desc_m_k.GetLength(I1);
// check consistency of desc
-if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) &&
-M == mean_var_count_grid_desc_m_n.GetLength(I0) &&
-N / NPerBlock == mean_var_count_grid_desc_m_n.GetLength(I1)))
+if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1)))
{
return false;
}
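
// Hedged reading of the two dropped conditions: the mean/var/count descriptor
// is now built directly as MakeMeanVarCountGridDescriptor_M_NBlock(MRaw, gemm_nblock_),
// so GetLength(I0) == M and GetLength(I1) == NBlock (the per-row block count)
// hold by construction, and re-validating them against e_grid_desc_m_n here
// became redundant.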
@@ -997,6 +993,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
static_for<0, PostShuffleThreadSliceSize_M, 1>{}([&](auto j) {
block_sync_lds();
+count_thread_buf = threadwise_welfords(i).cur_count_;
BlockwiseWelford::Run(
mean_thread_buf(j), var_thread_buf(j), count_thread_buf(j));
});
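
For context on what BlockwiseWelford::Run merges here: each thread holds a partial (mean, var, count) triple, and the block reduction combines partials pairwise with the standard Welford/Chan update. A minimal host-side sketch of that combine step (not CK's implementation; population-variance convention assumed):

#include <cstdio>

// Merge partial statistics (mean_b, var_b, count_b) into (mean_a, var_a, count_a).
void welford_merge(float& mean_a, float& var_a, int& count_a,
                   float mean_b, float var_b, int count_b)
{
    const int count = count_a + count_b;
    if(count == 0)
        return;
    const float delta = mean_b - mean_a;
    // M2 (sum of squared deviations) is var * count for each partial; combine exactly
    const float m2 = var_a * count_a + var_b * count_b +
                     delta * delta * (float(count_a) * count_b / count);
    mean_a += delta * count_b / count;
    var_a  = m2 / count;
    count_a = count;
}

int main()
{
    // partials over {1, 2} and {3, 4, 5}; merged: mean 3, population var 2, count 5
    float mean = 1.5f, var = 0.25f;
    int count = 2;
    welford_merge(mean, var, count, 4.0f, 2.0f / 3.0f, 3);
    std::printf("mean = %f, var = %f, count = %d\n", mean, var, count);
}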
@@ -1083,6 +1080,12 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
count_thread_buf,
mean_var_count_grid_desc_mblock_mperblock_nblock,
welford_count_grid_buf);
+float mean = static_cast<float>(mean_thread_buf(I0));
+float var = static_cast<float>(var_thread_buf(I0));
+int count = count_thread_buf(I0);
+if(i == 0 && get_thread_global_1d_id() == 0)
+printf("1st kernel mean = %f, var = %f, count = %d\n", mean, var, count);
});
} // shuffle C + Ds + welford + write out
......
@@ -34,15 +34,23 @@ template <typename EDataType,
index_t NThreadClusterSize,
index_t MThreadSliceSize,
index_t NThreadSliceSize,
-index_t ESrcYDstVectorDim,
+index_t ESrcHDstVectorDim,
index_t ESrcVectorSize,
-index_t YDstVectorSize,
+index_t HDstVectorSize,
index_t GammaSrcVectorSize,
index_t BetaSrcVectorSize,
index_t MeanVarSrcDstVectorSize>
struct GridwiseWelfordSecondHalfLayernorm2d
{
-static constexpr bool reorder_thread_cluster = (ESrcYDstVectorDim == 0);
+static_assert((ESrcHDstVectorDim == 0 && MThreadSliceSize % ESrcVectorSize == 0) ||
+(ESrcHDstVectorDim == 1 && NThreadSliceSize % ESrcVectorSize == 0),
+"Invalid thread slice sizes and/or vector sizes configuration, please check!");
+static_assert((ESrcHDstVectorDim == 0 && MThreadSliceSize % HDstVectorSize == 0) ||
+(ESrcHDstVectorDim == 1 && NThreadSliceSize % HDstVectorSize == 0),
+"Invalid thread slice sizes and/or vector sizes configuration, please check!");
+static constexpr bool reorder_thread_cluster = (ESrcHDstVectorDim == 0);
using ThreadClusterLengths_M_N = Sequence<MThreadClusterSize, NThreadClusterSize>;
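
The pair of static_asserts added above pins down a legality rule: each thread's slice must be an exact multiple of the vector access width along the vectorized dimension (M when ESrcHDstVectorDim == 0, N when it is 1), which is also what makes the EThreadBufferNumber-style constants introduced below exact vector counts. A toy configuration check with made-up numbers:

// Hypothetical values mirroring the asserts, with ESrcHDstVectorDim == 1
// (vector accesses run along N):
constexpr int NThreadSliceSize = 8;
constexpr int ESrcVectorSize   = 4; // 8 % 4 == 0 -> accepted, 2 E loads per slice row
constexpr int HDstVectorSize   = 2; // 8 % 2 == 0 -> accepted, 4 H stores per slice row
static_assert(NThreadSliceSize % ESrcVectorSize == 0, "E loads must tile the slice");
static_assert(NThreadSliceSize % HDstVectorSize == 0, "H stores must tile the slice");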
@@ -73,8 +81,14 @@ struct GridwiseWelfordSecondHalfLayernorm2d
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
-static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
-static constexpr index_t N_BlockTileSize = NThreadClusterSize * NThreadSliceSize;
+static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
+static constexpr index_t N_BlockTileSize = NThreadClusterSize * NThreadSliceSize;
+static constexpr index_t N_BlockTileStepSize = NThreadClusterSize * ESrcVectorSize;
+static constexpr auto EThreadBufferNumber = Number<NThreadSliceSize / ESrcVectorSize>{};
+static constexpr auto GammaThreadBufferNumber = Number<NThreadSliceSize / ESrcVectorSize>{};
+static constexpr auto BetaThreadBufferNumber = Number<NThreadSliceSize / ESrcVectorSize>{};
+static constexpr auto HThreadBufferNumber = Number<NThreadSliceSize / ESrcVectorSize>{};
__device__ static void Run(const EDataType* __restrict__ p_e_grid,
const MeanDataType* __restrict__ p_in_welford_mean_grid,
@@ -89,8 +103,8 @@ struct GridwiseWelfordSecondHalfLayernorm2d
const GammaBetaGridDesc_N& gamma_grid_desc_n,
const GammaBetaGridDesc_N& beta_grid_desc_n,
index_t gemm_nblock_,
-index_t num_mean_var_count_k_block_tile_iteration,
-index_t num_xy_k_block_tile_iteration,
+index_t numMeanVarCountBlockTileIteration_N,
+index_t numEBlockTileIteration_N,
ComputeDataType epsilon)
{
ignore = p_e_grid;
@@ -106,10 +120,206 @@ struct GridwiseWelfordSecondHalfLayernorm2d
ignore = gamma_grid_desc_n;
ignore = beta_grid_desc_n;
ignore = gemm_nblock_;
-ignore = num_mean_var_count_k_block_tile_iteration;
-ignore = num_xy_k_block_tile_iteration;
+ignore = numMeanVarCountBlockTileIteration_N;
+ignore = numEBlockTileIteration_N;
ignore = epsilon;
// float mean = static_cast<float>(p_in_welford_mean_grid[0]);
// float var = static_cast<float>(p_in_welford_var_grid[0]);
// int count = p_in_welford_count_grid[0];
// if(get_thread_global_1d_id() == 0)
// printf("kernel mean = %f, var = %f, count = %d\n", mean, var, count);
float mean = static_cast<float>(p_in_welford_mean_grid[0]);
if(get_thread_global_1d_id() == 0)
printf("mean = %f\n", mean);
int s = static_cast<int>(mean_var_count_grid_desc_m_n.GetElementSpaceSize());
if(get_thread_global_1d_id() == 0)
printf("mean_var_count_grid_desc_m_n.GetElementSpaceSize() = %d\n", s);
// using ThreadBufferLengths_1_1 = Sequence<1, 1>;
// constexpr auto thread_buffer_desc_1_1 =
// make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<1>{}));
// constexpr auto grid_desc_1_1 =
// make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<1>{}));
// const auto mean_grid = make_dynamic_buffer<AddressSpaceEnum::Global>(
// p_in_welford_mean_grid, mean_var_count_grid_desc_m_n.GetElementSpaceSize());
// StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, 1, true> mean_thread;
// float mean1 = (mean_grid.template Get<MeanDataType>(0, true));
// if(get_thread_global_1d_id() == 0)
// printf("global mean = %f\n", mean1);
// auto threadwise_mean_load_m_k =
// ThreadwiseTensorSliceTransfer_v2<MeanDataType,
// ComputeDataType,
// decltype(mean_var_count_grid_desc_m_n),
// decltype(thread_buffer_desc_1_1),
// ThreadBufferLengths_1_1,
// Sequence<0, 1>,
// 1,
// 1,
// 1,
// true>(mean_var_count_grid_desc_m_n,
// make_multi_index(0, 0));
// threadwise_mean_load_m_k.Run(mean_var_count_grid_desc_m_n,
// mean_grid,
// thread_buffer_desc_1_1,
// make_tuple(I0, I0),
// mean_thread);
// if(get_thread_global_1d_id() == 0)
// printf("threadwise mean = %f\n", mean_thread(Number<0>{}));
// // Thread/Block id
// const index_t thread_local_id = get_thread_local_1d_id();
// const index_t block_global_id = get_block_1d_id();
// const auto thread_cluster_idx =
// thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
// const auto thread_m_cluster_id = thread_cluster_idx[I0];
// const auto thread_n_cluster_id = thread_cluster_idx[I1];
// // step1: Merge mean and variance
// using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>;
// constexpr auto thread_buffer_desc_m_1 = make_naive_tensor_descriptor_packed(
// make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));
// auto threadwise_mean_load_m_k =
// ThreadwiseTensorSliceTransfer_v2<MeanDataType,
// ComputeDataType,
// MeanVarCountGridDesc_M_N,
// decltype(thread_buffer_desc_m_1),
// ThreadBufferLengths_M_1,
// Sequence<0, 1>,
// 1,
// 1,
// 1,
// true>(mean_var_count_grid_desc_m_n,
// make_multi_index(0, 0));
// auto threadwise_var_load_m_k =
// ThreadwiseTensorSliceTransfer_v2<VarDataType,
// ComputeDataType,
// MeanVarCountGridDesc_M_N,
// decltype(thread_buffer_desc_m_1),
// ThreadBufferLengths_M_1,
// Sequence<0, 1>,
// 1,
// 1,
// 1,
// true>(
// mean_var_count_grid_desc_m_n,
// make_multi_index(block_global_id * M_BlockTileSize +
// thread_m_cluster_id * MThreadSliceSize,
// thread_n_cluster_id));
// auto threadwise_count_load_m_k =
// ThreadwiseTensorSliceTransfer_v2<int32_t,
// int32_t,
// MeanVarCountGridDesc_M_N,
// decltype(thread_buffer_desc_m_1),
// ThreadBufferLengths_M_1,
// Sequence<0, 1>,
// 1,
// 1,
// 1,
// true>(
// mean_var_count_grid_desc_m_n,
// make_multi_index(block_global_id * M_BlockTileSize +
// thread_m_cluster_id * MThreadSliceSize,
// thread_n_cluster_id));
// const auto welford_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
// p_in_welford_mean_grid, mean_var_count_grid_desc_m_n.GetElementSpaceSize());
// const auto welford_var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
// p_in_welford_var_grid, mean_var_count_grid_desc_m_n.GetElementSpaceSize());
// const auto welford_count_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
// p_in_welford_count_grid, mean_var_count_grid_desc_m_n.GetElementSpaceSize());
// StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
// in_welford_mean_thread_buf;
// StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
// in_welford_var_thread_buf;
// StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
// in_welford_count_thread_buf;
// StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
// welford_mean_thread_buf;
// StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
// welford_var_thread_buf;
// StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
// welford_count_thread_buf;
// constexpr auto mean_var_count_thread_copy_step_m_n =
// make_multi_index(0, NThreadClusterSize);
// static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
// welford_mean_thread_buf(I) = type_convert<ComputeDataType>(0.0f);
// welford_var_thread_buf(I) = type_convert<ComputeDataType>(0.0f);
// welford_count_thread_buf(I) = 0;
// });
// for(index_t reducedTiles = 0; reducedTiles < numMeanVarCountBlockTileIteration_N;
// ++reducedTiles)
// {
// threadwise_mean_load_m_k.Run(mean_var_count_grid_desc_m_n,
// welford_mean_global_val_buf,
// thread_buffer_desc_m_1,
// make_tuple(I0, I0),
// in_welford_mean_thread_buf);
// // threadwise_var_load_m_k.Run(mean_var_count_grid_desc_m_n,
// // welford_var_global_val_buf,
// // thread_buffer_desc_m_1,
// // make_tuple(I0, I0),
// // in_welford_var_thread_buf);
// // threadwise_count_load_m_k.Run(mean_var_count_grid_desc_m_n,
// // welford_count_global_val_buf,
// // thread_buffer_desc_m_1,
// // make_tuple(I0, I0),
// // in_welford_count_thread_buf);
// // ThreadwiseWelford::Run(in_welford_mean_thread_buf,
// // in_welford_var_thread_buf,
// // in_welford_count_thread_buf,
// // welford_mean_thread_buf,
// // welford_var_thread_buf,
// // welford_count_thread_buf);
// // threadwise_mean_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_n,
// // mean_var_count_thread_copy_step_m_n);
// // threadwise_var_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_n,
// // mean_var_count_thread_copy_step_m_n);
// // threadwise_count_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_n,
// // mean_var_count_thread_copy_step_m_n);
// static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
// if(get_thread_global_1d_id() == 0)
// printf("mean = %f, var = %f, count = %d\n",
// in_welford_mean_thread_buf(I),
// in_welford_var_thread_buf(I),
// in_welford_count_thread_buf(I));
// });
// }
// static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
// if constexpr(I > 0)
// block_sync_lds();
// if(get_thread_global_1d_id() == 0)
// printf("count = %d\n", welford_count_thread_buf(I));
// BlockwiseWelford::Run(
// welford_mean_thread_buf(I), welford_var_thread_buf(I),
// welford_count_thread_buf(I));
// });
} // run
};
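
Since the Run body above is still scaffolding (every argument is ignored and the load/merge path is commented out), here is a sketch of the per-element math this second-half kernel is being built toward, assuming the conventional layernorm with per-row mean/var and per-column gamma/beta:

#include <cmath>

// h[m][n] = gamma[n] * (e[m][n] - mean[m]) / sqrt(var[m] + epsilon) + beta[n]
float layernorm_elem(float e, float mean, float var, float gamma, float beta, float epsilon)
{
    return gamma * (e - mean) / std::sqrt(var + epsilon) + beta;
}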
......