Commit f0d626c3 authored by rtmadduri's avatar rtmadduri
Browse files

Fix threadslicing incorrect dims

parent e5d6cf9c
...@@ -54,7 +54,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemmXdlSpl ...@@ -54,7 +54,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemmXdlSpl
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 3, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 3, 8, 8, 1, 1, 1, S<32, 1, 8>, 8>; < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 3, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
// clang-format on // clang-format on
#include "run_grouped_gemm_example.inc" #include "run_grouped_gemm_example.inc"
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -170,18 +170,18 @@ struct DeviceGroupedGemmSplitkInstanceWrapper ...@@ -170,18 +170,18 @@ struct DeviceGroupedGemmSplitkInstanceWrapper
using I = ck::Number<N>; using I = ck::Number<N>;
using ABlockTransferThreadClusterArrageOrder = using ABlockTransferThreadClusterArrageOrder =
std::conditional_t<std::is_same_v<ALayout, Row>, S<0, 2, 1, 3>, S<0, 1, 3, 2>>; std::conditional_t<std::is_same_v<ALayout, Row>, S<1, 0, 2>, S<0, 2, 1>>;
using ABlockTransferSrcAccessOrder = using ABlockTransferSrcAccessOrder =
std::conditional_t<std::is_same_v<ALayout, Row>, S<0, 2, 1, 3>, S<0, 1, 3, 2>>; std::conditional_t<std::is_same_v<ALayout, Row>, S<1, 0, 2>, S<0, 2, 1>>;
using ABlockTransferSrcVectorDim = std::conditional_t<std::is_same_v<ALayout, Row>, I<3>, I<2>>; using ABlockTransferSrcVectorDim = std::conditional_t<std::is_same_v<ALayout, Row>, I<3>, I<2>>;
using ABlockTransferDstScalarPerVector_K1 = using ABlockTransferDstScalarPerVector_K1 =
std::conditional_t<std::is_same_v<ALayout, Row>, I<8>, I<2>>; std::conditional_t<std::is_same_v<ALayout, Row>, I<8>, I<2>>;
using ABlockLdsAddExtraM = std::conditional_t<std::is_same_v<ALayout, Row>, I<1>, I<0>>; using ABlockLdsAddExtraM = std::conditional_t<std::is_same_v<ALayout, Row>, I<1>, I<0>>;
using BBlockTransferThreadClusterArrageOrder = using BBlockTransferThreadClusterArrageOrder =
std::conditional_t<std::is_same_v<BLayout, Row>, S<0, 1, 3, 2>, S<0, 2, 1, 3>>; std::conditional_t<std::is_same_v<BLayout, Row>, S<0, 2, 1>, S<1, 0, 2>>;
using BBlockTransferSrcAccessOrder = using BBlockTransferSrcAccessOrder =
std::conditional_t<std::is_same_v<BLayout, Row>, S<0, 1, 3, 2>, S<0, 2, 1, 3>>; std::conditional_t<std::is_same_v<BLayout, Row>, S<0, 2, 1>, S<1, 0, 2>>;
using BBlockTransferSrcVectorDim = std::conditional_t<std::is_same_v<BLayout, Row>, I<2>, I<3>>; using BBlockTransferSrcVectorDim = std::conditional_t<std::is_same_v<BLayout, Row>, I<2>, I<3>>;
using BBlockTransferDstScalarPerVector_K1 = using BBlockTransferDstScalarPerVector_K1 =
std::conditional_t<std::is_same_v<ALayout, Row>, I<2>, I<8>>; std::conditional_t<std::is_same_v<ALayout, Row>, I<2>, I<8>>;
...@@ -214,14 +214,14 @@ struct DeviceGroupedGemmSplitkInstanceWrapper ...@@ -214,14 +214,14 @@ struct DeviceGroupedGemmSplitkInstanceWrapper
32, 32,
4, 4,
2, 2,
S<1, 4, 16, 1>, S<4, 16, 1>,
ABlockTransferThreadClusterArrageOrder, ABlockTransferThreadClusterArrageOrder,
ABlockTransferSrcAccessOrder, ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim::value, ABlockTransferSrcVectorDim::value,
ABlockTransferSrcScalarPerVector, ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1::value, ABlockTransferDstScalarPerVector_K1::value,
ABlockLdsAddExtraM::value, ABlockLdsAddExtraM::value,
S<1, 4, 16, 1>, S<4, 16, 1>,
BBlockTransferThreadClusterArrageOrder, BBlockTransferThreadClusterArrageOrder,
BBlockTransferSrcAccessOrder, BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim::value, BBlockTransferSrcVectorDim::value,
...@@ -230,7 +230,7 @@ struct DeviceGroupedGemmSplitkInstanceWrapper ...@@ -230,7 +230,7 @@ struct DeviceGroupedGemmSplitkInstanceWrapper
BBlockLdsAddExtraM::value, BBlockLdsAddExtraM::value,
1, 1,
1, 1,
S<1, 16, 1, 8>, S<16, 1, 8>,
CDEBlockTransferScalarPerVector_NPerBlock>; CDEBlockTransferScalarPerVector_NPerBlock>;
bool IsSupported(const std::vector<int>& Ms, bool IsSupported(const std::vector<int>& Ms,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment