"git@developer.sourcefind.cn:zhaoyu6/sglang.git" did not exist on "e94e60d6fbb39d967638347c01a711cbe82e2c42"
Commit ccf94638 authored by Mateusz Ozga's avatar Mateusz Ozga
Browse files

Pass 4d sequence and convert to 3d

parent 860433ea
......@@ -48,16 +48,16 @@ using DeviceConvBwdWeightInstance =
16, // NPerXdl
1, // MXdlPerWave
1, // NXdlPerWave
S<4, 16, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1
S<2, 0, 1>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
S<1, 4, 16, 4>, // ABlockTransferThreadClusterLengths_K0_M_K1
S<0, 3, 1, 2>, // ABlockTransferThreadClusterArrangeOrder
S<0, 2, 1, 3>, // ABlockTransferSrcAccessOrder
1, // ABlockTransferSrcVectorDim
1, // ABlockTransferSrcScalarPerVector
4, // ABlockTransferDstScalarPerVector_K1
false, // ABlockLdsAddExtraM
S<4, 16, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1
S<2, 0, 1>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
S<1, 4, 16, 4>, // BBlockTransferThreadClusterLengths_K0_N_K1
S<0, 3, 1, 2>, // BBlockTransferThreadClusterArrangeOrder
S<0, 2, 1, 3>, // BBlockTransferSrcAccessOrder
1, // BBlockTransferSrcVectorDim
1, // BBlockTransferSrcScalarPerVector
4, // BBlockTransferDstScalarPerVector_K1
......
......@@ -47,16 +47,16 @@ using DeviceConvBwdWeightInstance =
32, // NPerXdl
2, // MXdlPerWave
2, // NXdlPerWave
S<4, 16, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1
S<2, 0, 1>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
S<1, 4, 16, 4>, // ABlockTransferThreadClusterLengths_K0_M_K1
S<0, 3, 1, 2>, // ABlockTransferThreadClusterArrangeOrder
S<0, 2, 1, 3>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector
2, // ABlockTransferDstScalarPerVector_K1
false, // ABlockLdsAddExtraM
S<4, 16, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1
S<2, 0, 1>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
S<1, 4, 16, 4>, // ABlockTransferThreadClusterLengths_K0_M_K1
S<0, 3, 1, 2>, // ABlockTransferThreadClusterArrangeOrder
S<0, 2, 1, 3>, // ABlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector
2, // BBlockTransferDstScalarPerVector_K1
......
......@@ -49,16 +49,16 @@ using DeviceConvBwdWeightInstance =
16, // NPerXdl
1, // MXdlPerWave
1, // NXdlPerWave
S<4, 16, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1
S<2, 0, 1>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
S<1, 4, 16, 4>, // ABlockTransferThreadClusterLengths_K0_M_K1
S<0, 3, 1, 2>, // ABlockTransferThreadClusterArrangeOrder
S<0, 2, 1, 3>, // ABlockTransferSrcAccessOrder
1, // ABlockTransferSrcVectorDim
1, // ABlockTransferSrcScalarPerVector
4, // ABlockTransferDstScalarPerVector_K1
4, // ABlockTranstest/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface_xdl.cppferDstScalarPerVector_K1
false, // ABlockLdsAddExtraM
S<4, 16, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1
S<2, 0, 1>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
S<1, 4, 16, 4>, // BBlockTransferThreadClusterLengths_K0_N_K1
S<0, 3, 1, 2>, // BBlockTransferThreadClusterArrangeOrder
S<0, 2, 1, 3>, // BBlockTransferSrcAccessOrder
1, // BBlockTransferSrcVectorDim
1, // BBlockTransferSrcScalarPerVector
4, // BBlockTransferDstScalarPerVector_K1
......
......@@ -315,14 +315,43 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
batch);
}
template <typename SeqType>
constexpr static auto
ShuffleSequenceAndTransformFrom4DTo3D() noexcept(noexcept(SeqType{}.Size() == 4))
-> decltype(auto)
{
// Remove first element and,
// Convert 4d->3d sequence.
constexpr auto _I0 = SeqType{}.At(I1);
constexpr auto _I1 = SeqType{}.At(I2);
constexpr auto _I2 = SeqType{}.At(I0);
constexpr auto _Seq = S<_I0, _I1, _I2>();
return _Seq;
}
template <typename SeqType>
constexpr static auto
TransformSequenceFrom4DTo3dAndReduceByOne() noexcept(noexcept(SeqType{}.Size() == 4))
-> decltype(auto)
{
// Skip first element and
// Convert 4d->3d and take away one from seq.
constexpr index_t one = 1;
constexpr auto _I0 = SeqType{}.At(I1) - one;
constexpr auto _I1 = SeqType{}.At(I2) - one;
constexpr auto _I2 = SeqType{}.At(I3) - one;
constexpr auto _Seq = S<_I0, _I1, _I2>();
return _Seq;
}
using ABCGridDescs = decltype(GetABCGridDesc<NDimSpatial>());
using AGridDesc_K0_M_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
using CGridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
using GridwiseGemm =
GridwiseGemm_xdl_cshuffle_v3<tensor_layout::gemm::RowMajor,
using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3<
tensor_layout::gemm::RowMajor,
tensor_layout::gemm::ColumnMajor,
tensor_layout::gemm::RowMajor,
ADataType,
......@@ -344,17 +373,21 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
NPerXdl,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
decltype(ShuffleSequenceAndTransformFrom4DTo3D<
ABlockTransferThreadClusterLengths_K0_M_K1>()),
decltype(TransformSequenceFrom4DTo3dAndReduceByOne<
ABlockTransferThreadClusterArrangeOrder>()),
decltype(TransformSequenceFrom4DTo3dAndReduceByOne<ABlockTransferSrcAccessOrder>()),
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
false,
ABlockLdsAddExtraM,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
decltype(ShuffleSequenceAndTransformFrom4DTo3D<
BBlockTransferThreadClusterLengths_K0_N_K1>()),
decltype(TransformSequenceFrom4DTo3dAndReduceByOne<
BBlockTransferThreadClusterArrangeOrder>()),
decltype(TransformSequenceFrom4DTo3dAndReduceByOne<BBlockTransferSrcAccessOrder>()),
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
......
......@@ -201,16 +201,16 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
16, // NPerXdl
1, // MXdlPerWave
1, // NXdlPerWave
S<4, 16, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1
S<2, 0, 1>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
S<1, 4, 16, 4>, // ABlockTransferThreadClusterLengths_K0_M_K1
S<0, 3, 1, 2>, // ABlockTransferThreadClusterArrangeOrder
S<0, 2, 1, 3>, // ABlockTransferSrcAccessOrder
1, // ABlockTransferSrcVectorDim
1, // ABlockTransferSrcScalarPerVector
4, // ABlockTransferDstScalarPerVector_K1
false, // ABlockLdsAddExtraM
S<4, 16, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1
S<2, 0, 1>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
S<1, 4, 16, 4>, // BBlockTransferThreadClusterLengths_K0_N_K1
S<0, 3, 1, 2>, // BBlockTransferThreadClusterArrangeOrder
S<0, 2, 1, 3>, // BBlockTransferSrcAccessOrder
1, // BBlockTransferSrcVectorDim
1, // BBlockTransferSrcScalarPerVector
4, // BBlockTransferDstScalarPerVector_K1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment