Commit 441d8973 authored by rocking's avatar rocking
Browse files

Share the vector size

parent c83bad81
......@@ -60,11 +60,11 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
// clang-format off
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDLayernorm_Xdl_CShuffle
//######| ALayout| BLayout| DsLayout| HLayout| AData| BData| AccData| CShuffle| DsData| GammaData| BetaData| HData| A| B| CDE| H| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| PostShuffle| PostShuffle| Layernorm| Layernorm| Layernorm| Layernorm| Layernorm|
//######| | | | | Type| Type| Type| DataType| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| ThreadClusterSize| ThreadSliceSize| HDst| GammaSrc| BetaSrc|
//######| | | | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _M_N| _M_N| _M_N| _M_N| VectorSize| VectorSize| VectorSize|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, DsLayout, HLayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, GammaDataType, BetaDataType, HDataType, AElementOp, BElementOp, CDEElementOp, HElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<32, 8>, 8, S<8, 32>, S<1, 8>, 8, 8, 8>;
//######| ALayout| BLayout| DsLayout| HLayout| AData| BData| AccData| CShuffle| DsData| GammaData| BetaData| HData| A| B| CDE| H| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| PostShuffle| PostShuffle| Layernorm| Layernorm|
//######| | | | | Type| Type| Type| DataType| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| ThreadClusterSize| ThreadSliceSize|
//######| | | | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _M_N| _M_N| _M_N| _M_N|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, DsLayout, HLayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, GammaDataType, BetaDataType, HDataType, AElementOp, BElementOp, CDEElementOp, HElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<32, 8>, 8, S<8, 32>, S<1, 8>>;
// clang-format on
auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) {
......
......@@ -231,9 +231,6 @@ template <typename ALayout,
index_t PostShuffleScalarPerVector,
typename LayernormThreadClusterSize_M_N,
typename LayernormThreadSliceSize_M_N,
index_t LayernormHDstVectorSize,
index_t LayernormGammaSrcVectorSize,
index_t LayernormBetaSrcVectorSize,
LoopScheduler LoopSched = make_default_loop_scheduler(),
PipelineVersion PipelineVer = PipelineVersion::v1>
struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
......@@ -259,7 +256,10 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
using VarDataType = CShuffleDataType;
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr index_t LayernormESrcVectorSize = LayernormHDstVectorSize;
static constexpr index_t LayernormHDstVectorSize = PostShuffleScalarPerVector;
static constexpr index_t LayernormGammaSrcVectorSize = PostShuffleScalarPerVector;
static constexpr index_t LayernormBetaSrcVectorSize = PostShuffleScalarPerVector;
static constexpr index_t LayernormESrcVectorSize = PostShuffleScalarPerVector;
using LayernormBlockTileSize_M_N =
Sequence<LayernormThreadClusterSize_M_N::At(0) * LayernormThreadSliceSize_M_N::At(0),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment