"vscode:/vscode.git/clone" did not exist on "9904929b41ced2638e3eb023bdf9b55c29157b67"
Commit 66c70dfe authored by Adam Osewski's avatar Adam Osewski
Browse files

Hide unused tparams from device op and copy kernel args directly when setting pointer

parent 1a1fd0b3
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
...@@ -52,11 +52,11 @@ static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecial ...@@ -52,11 +52,11 @@ static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecial
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemmMultipleDSplitKXdlCShuffle using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemmMultipleDSplitKXdlCShuffle
// clang-format off // clang-format off
//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| AThreadTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BThreadTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcReset| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcReset| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| CoordinateAfter| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| CoordinateAfter| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Run| | | | | | | | Run| | | | | | //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, false, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, false, 1, 1, 1, S<1, 32, 1, 8>, 4>; < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>;
// clang-format on // clang-format on
struct ProblemSize final struct ProblemSize final
...@@ -226,10 +226,6 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co ...@@ -226,10 +226,6 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, c_element_op); p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, c_element_op);
DeviceMem gemm_arg_dev_mem(gemm.GetDeviceKernelArgSize(&argument)); DeviceMem gemm_arg_dev_mem(gemm.GetDeviceKernelArgSize(&argument));
hip_check_error(hipMemcpy(gemm_arg_dev_mem.GetDeviceBuffer(),
grouped_gemm_kernel_args_.data(),
gemm.GetDeviceKernelArgSize(&argument),
hipMemcpyHostToDevice));
if(!gemm.IsSupportedArgument(argument)) if(!gemm.IsSupportedArgument(argument))
{ {
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -119,7 +119,7 @@ struct DeviceGroupedGemmMultipleDSplitK : public DeviceGroupedGemm<ALayout, ...@@ -119,7 +119,7 @@ struct DeviceGroupedGemmMultipleDSplitK : public DeviceGroupedGemm<ALayout,
/// @param[in] p_dev_kernel_args The pointer to the device memory which contains kernel /// @param[in] p_dev_kernel_args The pointer to the device memory which contains kernel
/// arguments. /// arguments.
/// ///
virtual void SetDeviceKernelArgs(BaseArgument* p_arg, const void* p_dev_kernel_args) const = 0; virtual void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const = 0;
//---------------------------------------------------------------------------------------------- //----------------------------------------------------------------------------------------------
/// @brief Gets the device kernel argument size. /// @brief Gets the device kernel argument size.
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -254,7 +254,6 @@ template <typename ALayout, ...@@ -254,7 +254,6 @@ template <typename ALayout,
index_t ABlockTransferSrcVectorDim, index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector, index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1, index_t ABlockTransferDstScalarPerVector_AK1,
bool AThreadTransferSrcResetCoordinateAfterRun,
index_t ABlockLdsExtraM, index_t ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1, typename BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder, typename BBlockTransferThreadClusterArrangeOrder,
...@@ -262,14 +261,13 @@ template <typename ALayout, ...@@ -262,14 +261,13 @@ template <typename ALayout,
index_t BBlockTransferSrcVectorDim, index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector, index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1, index_t BBlockTransferDstScalarPerVector_BK1,
bool BThreadTransferSrcResetCoordinateAfterRun,
index_t BBlockLdsExtraN, index_t BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle, index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock, index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler(),
PipelineVersion PipelineVer = PipelineVersion::v1, PipelineVersion PipelineVer = PipelineVersion::v1,
LoopScheduler LoopSched = make_default_loop_scheduler(),
typename ComputeDataType = EDataType> typename ComputeDataType = EDataType>
struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
: public DeviceGroupedGemmMultipleDSplitK<ALayout, : public DeviceGroupedGemmMultipleDSplitK<ALayout,
...@@ -327,7 +325,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle ...@@ -327,7 +325,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
ABlockTransferSrcVectorDim, ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector, ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1, ABlockTransferDstScalarPerVector_AK1,
AThreadTransferSrcResetCoordinateAfterRun, false, // AThreadTransferSrcResetCoordinateAfterRun,
ABlockLdsExtraM, ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1, BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder, BBlockTransferThreadClusterArrangeOrder,
...@@ -335,7 +333,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle ...@@ -335,7 +333,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
BBlockTransferSrcVectorDim, BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1, BBlockTransferDstScalarPerVector_BK1,
BThreadTransferSrcResetCoordinateAfterRun, false, // BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsExtraN, BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle, CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle,
...@@ -965,12 +963,16 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle ...@@ -965,12 +963,16 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
return str.str(); return str.str();
} }
static void SetDeviceKernelArgs(Argument& arg, const void* p_dev_kernel_args) void SetDeviceKernelArgs(Argument& arg, void* p_dev_kernel_args) const
{ {
arg.p_dev_gemm_args_ = p_dev_kernel_args; arg.p_dev_gemm_args_ = p_dev_kernel_args;
hip_check_error(hipMemcpy(p_dev_kernel_args,
arg.gemm_kernel_args_.data(),
GetDeviceKernelArgSize(&arg),
hipMemcpyHostToDevice));
} }
void SetDeviceKernelArgs(BaseArgument* p_arg, const void* p_dev_kernel_args) const override void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override
{ {
return SetDeviceKernelArgs(*dynamic_cast<Argument*>(p_arg), p_dev_kernel_args); return SetDeviceKernelArgs(*dynamic_cast<Argument*>(p_arg), p_dev_kernel_args);
} }
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment