Commit 328cc7f4 authored by rocking
Browse files

Hard code the vector dim

parent eead0864
...@@ -60,11 +60,11 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa ...@@ -60,11 +60,11 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
// clang-format off // clang-format off
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDLayernorm_Xdl_CShuffle using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDLayernorm_Xdl_CShuffle
//######| ALayout| BLayout| DsLayout| HLayout| AData| BData| AccData| CShuffle| DsData| GammaData| BetaData| HData| A| B| CDE| H| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| PostShuffle| PostShuffle| Layernorm| Layernorm| Layernorm| Layernorm| Layernorm| Layernorm| Layernorm| //######| ALayout| BLayout| DsLayout| HLayout| AData| BData| AccData| CShuffle| DsData| GammaData| BetaData| HData| A| B| CDE| H| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| PostShuffle| PostShuffle| Layernorm| Layernorm| Layernorm| Layernorm| Layernorm| Layernorm|
//######| | | | | Type| Type| Type| DataType| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| ThreadClusterSize| ThreadSliceSize| ESrcHDst| ESrc| HDst| GammaSrc| BetaSrc| //######| | | | | Type| Type| Type| DataType| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| ThreadClusterSize| ThreadSliceSize| ESrc| HDst| GammaSrc| BetaSrc|
//######| | | | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _M_N| _M_N| _M_N| _M_N| VectorDim| VectorSize| VectorSize| VectorSize| VectorSize| //######| | | | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _M_N| _M_N| _M_N| _M_N| VectorSize| VectorSize| VectorSize| VectorSize|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, DsLayout, HLayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, GammaDataType, BetaDataType, HDataType, AElementOp, BElementOp, CDEElementOp, HElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<64, 4>, 4, S<8, 32>, S<1, 8>, 1, 8, 8, 8, 8>; < ALayout, BLayout, DsLayout, HLayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, GammaDataType, BetaDataType, HDataType, AElementOp, BElementOp, CDEElementOp, HElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<64, 4>, 4, S<8, 32>, S<1, 8>, 8, 8, 8, 8>;
// clang-format on // clang-format on
auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) { auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) {
......
...@@ -223,7 +223,6 @@ template <typename ALayout, ...@@ -223,7 +223,6 @@ template <typename ALayout,
index_t PostShuffleScalarPerVector, index_t PostShuffleScalarPerVector,
typename LayernormThreadClusterSize_M_N, typename LayernormThreadClusterSize_M_N,
typename LayernormThreadSliceSize_M_N, typename LayernormThreadSliceSize_M_N,
index_t LayernormESrcHDstVectorDim,
index_t LayernormESrcVectorSize, index_t LayernormESrcVectorSize,
index_t LayernormHDstVectorSize, index_t LayernormHDstVectorSize,
index_t LayernormGammaSrcVectorSize, index_t LayernormGammaSrcVectorSize,
...@@ -485,7 +484,6 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -485,7 +484,6 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
LayernormThreadClusterSize_M_N::At(I1), LayernormThreadClusterSize_M_N::At(I1),
LayernormThreadSliceSize_M_N::At(I0), LayernormThreadSliceSize_M_N::At(I0),
LayernormThreadSliceSize_M_N::At(I1), LayernormThreadSliceSize_M_N::At(I1),
LayernormESrcHDstVectorDim,
LayernormESrcVectorSize, LayernormESrcVectorSize,
LayernormHDstVectorSize, LayernormHDstVectorSize,
LayernormGammaSrcVectorSize, LayernormGammaSrcVectorSize,
...@@ -908,7 +906,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -908,7 +906,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
// check vector store of E // check vector store of E
// only support RowMajor for now // only support RowMajor for now
if constexpr(is_same_v<ELayout, Row>) if constexpr(is_same_v<ELayout, Row> && is_same_v<HLayout, Row>)
{ {
if(arg.NRaw_ % PostShuffleScalarPerVector != 0) if(arg.NRaw_ % PostShuffleScalarPerVector != 0)
{ {
......
...@@ -35,20 +35,18 @@ template <typename EDataType, ...@@ -35,20 +35,18 @@ template <typename EDataType,
index_t NThreadClusterSize, index_t NThreadClusterSize,
index_t MThreadSliceSize, index_t MThreadSliceSize,
index_t NThreadSliceSize, index_t NThreadSliceSize,
index_t ESrcHDstVectorDim,
index_t ESrcVectorSize, index_t ESrcVectorSize,
index_t HDstVectorSize, index_t HDstVectorSize,
index_t GammaSrcVectorSize, index_t GammaSrcVectorSize,
index_t BetaSrcVectorSize> index_t BetaSrcVectorSize>
struct GridwiseWelfordSecondHalfLayernorm2d struct GridwiseWelfordSecondHalfLayernorm2d
{ {
// TODO - Support ESrcHDstVectorDim == 0 static_assert(NThreadSliceSize % ESrcVectorSize == 0 &&
static_assert(ESrcHDstVectorDim == 1 && NThreadSliceSize % ESrcVectorSize == 0 &&
NThreadSliceSize % GammaSrcVectorSize == 0 && NThreadSliceSize % GammaSrcVectorSize == 0 &&
NThreadSliceSize % BetaSrcVectorSize == 0, NThreadSliceSize % BetaSrcVectorSize == 0,
"Invalid thread slice sizes and/or vector sizes configuration, please check!"); "Invalid thread slice sizes and/or vector sizes configuration, please check!");
static_assert(ESrcHDstVectorDim == 1 && NThreadSliceSize % HDstVectorSize == 0, static_assert(NThreadSliceSize % HDstVectorSize == 0,
"Invalid thread slice sizes and/or vector sizes configuration, please check!"); "Invalid thread slice sizes and/or vector sizes configuration, please check!");
using ThreadClusterLengths_M_N = Sequence<MThreadClusterSize, NThreadClusterSize>; using ThreadClusterLengths_M_N = Sequence<MThreadClusterSize, NThreadClusterSize>;
...@@ -227,7 +225,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d ...@@ -227,7 +225,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d
decltype(thread_buffer_desc_m_n), decltype(thread_buffer_desc_m_n),
ThreadBufferLengths_M_N, ThreadBufferLengths_M_N,
ThreadBufferDimAccessOrder, ThreadBufferDimAccessOrder,
ESrcHDstVectorDim, 1, // SrcVectorDim
ESrcVectorSize, ESrcVectorSize,
1, 1,
true>( true>(
...@@ -270,7 +268,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d ...@@ -270,7 +268,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d
HElementwiseOperation, HElementwiseOperation,
ThreadBufferLengths_M_N, ThreadBufferLengths_M_N,
ThreadBufferDimAccessOrder, ThreadBufferDimAccessOrder,
ESrcHDstVectorDim, 1, // DstVectorDim
HDstVectorSize, HDstVectorSize,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
1, 1,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment