Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
66c70dfe
Commit
66c70dfe
authored
Jan 10, 2024
by
Adam Osewski
Browse files
Hide unused tparams from device op and copy kernel args directly when setting pointer
parent
1a1fd0b3
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
19 additions
and
21 deletions
+19
-21
example/15_grouped_gemm/grouped_gemm_multiple_d_splitk_xdl_fp16.cpp
..._grouped_gemm/grouped_gemm_multiple_d_splitk_xdl_fp16.cpp
+6
-10
include/ck/tensor_operation/gpu/device/device_grouped_gemm_multiple_d_splitk.hpp
...tion/gpu/device/device_grouped_gemm_multiple_d_splitk.hpp
+2
-2
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_tile_loop.hpp
...grouped_gemm_multiple_d_splitk_xdl_cshuffle_tile_loop.hpp
+10
-8
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle_v2.hpp
.../grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle_v2.hpp
+1
-1
No files found.
example/15_grouped_gemm/grouped_gemm_multiple_d_splitk_xdl_fp16.cpp
View file @
66c70dfe
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
...
...
@@ -52,11 +52,11 @@ static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecial
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGroupedGemmMultipleDSplitKXdlCShuffle
// clang-format off
//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer|
AThreadTransfer|
ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer|
BThreadTransfer|
BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar|
SrcReset|
AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar|
SrcReset|
AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1|
CoordinateAfter|
| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1|
CoordinateAfter|
| PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
Run|
| | | | | | |
Run|
| | | | |
<
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
,
GemmMNKPadding
,
1
,
256
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
false
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
false
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
4
>
;
//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
<
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
,
GemmMNKPadding
,
1
,
256
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
4
>
;
// clang-format on
struct
ProblemSize
final
...
...
@@ -226,10 +226,6 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
p_As
,
p_Bs
,
p_Ds
,
p_Cs
,
gemm_descs
,
a_element_op
,
b_element_op
,
c_element_op
);
DeviceMem
gemm_arg_dev_mem
(
gemm
.
GetDeviceKernelArgSize
(
&
argument
));
hip_check_error
(
hipMemcpy
(
gemm_arg_dev_mem
.
GetDeviceBuffer
(),
grouped_gemm_kernel_args_
.
data
(),
gemm
.
GetDeviceKernelArgSize
(
&
argument
),
hipMemcpyHostToDevice
));
if
(
!
gemm
.
IsSupportedArgument
(
argument
))
{
...
...
include/ck/tensor_operation/gpu/device/device_grouped_gemm_multiple_d_splitk.hpp
View file @
66c70dfe
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023
-2024
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -119,7 +119,7 @@ struct DeviceGroupedGemmMultipleDSplitK : public DeviceGroupedGemm<ALayout,
/// @param[in] p_dev_kernel_args The pointer to the device memory which will hold the kernel
/// arguments.
///
virtual
void
SetDeviceKernelArgs
(
BaseArgument
*
p_arg
,
const
void
*
p_dev_kernel_args
)
const
=
0
;
virtual
void
SetDeviceKernelArgs
(
BaseArgument
*
p_arg
,
void
*
p_dev_kernel_args
)
const
=
0
;
//----------------------------------------------------------------------------------------------
/// @brief Gets the device kernel argument size.
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_tile_loop.hpp
View file @
66c70dfe
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -254,7 +254,6 @@ template <typename ALayout,
index_t
ABlockTransferSrcVectorDim
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
ABlockTransferDstScalarPerVector_AK1
,
bool
AThreadTransferSrcResetCoordinateAfterRun
,
index_t
ABlockLdsExtraM
,
typename
BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1
,
typename
BBlockTransferThreadClusterArrangeOrder
,
...
...
@@ -262,14 +261,13 @@ template <typename ALayout,
index_t
BBlockTransferSrcVectorDim
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
BBlockTransferDstScalarPerVector_BK1
,
bool
BThreadTransferSrcResetCoordinateAfterRun
,
index_t
BBlockLdsExtraN
,
index_t
CShuffleMXdlPerWavePerShuffle
,
index_t
CShuffleNXdlPerWavePerShuffle
,
typename
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CDEShuffleBlockTransferScalarPerVector_NPerBlock
,
LoopScheduler
LoopSched
=
make_default_loop_scheduler
(),
PipelineVersion
PipelineVer
=
PipelineVersion
::
v1
,
LoopScheduler
LoopSched
=
make_default_loop_scheduler
(),
typename
ComputeDataType
=
EDataType
>
struct
DeviceGroupedGemmMultipleDSplitKXdlCShuffle
:
public
DeviceGroupedGemmMultipleDSplitK
<
ALayout
,
...
...
@@ -327,7 +325,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
ABlockTransferSrcVectorDim
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_AK1
,
AThreadTransferSrcResetCoordinateAfterRun
,
false
,
//
AThreadTransferSrcResetCoordinateAfterRun,
ABlockLdsExtraM
,
BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1
,
BBlockTransferThreadClusterArrangeOrder
,
...
...
@@ -335,7 +333,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
BBlockTransferSrcVectorDim
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_BK1
,
BThreadTransferSrcResetCoordinateAfterRun
,
false
,
//
BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsExtraN
,
CShuffleMXdlPerWavePerShuffle
,
CShuffleNXdlPerWavePerShuffle
,
...
...
@@ -965,12 +963,16 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
return
str
.
str
();
}
static
void
SetDeviceKernelArgs
(
Argument
&
arg
,
const
void
*
p_dev_kernel_args
)
/// @brief Binds a device buffer to the argument and uploads the kernel arguments into it.
///
/// Stores @p p_dev_kernel_args in `arg.p_dev_gemm_args_`, then copies the contents of
/// `arg.gemm_kernel_args_` into that buffer with a host-to-device `hipMemcpy`. The byte
/// count of the copy is `GetDeviceKernelArgSize(&arg)`, and the hipMemcpy status is
/// passed to `hip_check_error`.
///
/// @param[in,out] arg               Argument whose `p_dev_gemm_args_` member is set.
/// @param[in]     p_dev_kernel_args Destination of the copy. Presumably must point to at
///                                  least `GetDeviceKernelArgSize(&arg)` bytes of device
///                                  memory -- TODO confirm against callers.
void
SetDeviceKernelArgs
(
Argument
&
arg
,
void
*
p_dev_kernel_args
)
const
{
arg
.
p_dev_gemm_args_
=
p_dev_kernel_args
;
// Upload the kernel arguments so the caller no longer has to copy them itself.
hip_check_error
(
hipMemcpy
(
p_dev_kernel_args
,
arg
.
gemm_kernel_args_
.
data
(),
GetDeviceKernelArgSize
(
&
arg
),
hipMemcpyHostToDevice
));
}
void
SetDeviceKernelArgs
(
BaseArgument
*
p_arg
,
const
void
*
p_dev_kernel_args
)
const
override
/// @brief Type-erased overload: downcasts @p p_arg to `Argument` and forwards to the
///        `Argument&` overload together with @p p_dev_kernel_args.
///
/// NOTE(review): the result of `dynamic_cast<Argument*>(p_arg)` is dereferenced without a
/// null check, so a @p p_arg that is null or does not point to an `Argument` results in a
/// null-pointer dereference here.
///
/// @param[in,out] p_arg             Base-class pointer expected to refer to an `Argument`.
/// @param[in]     p_dev_kernel_args Device buffer pointer passed through unchanged.
void
SetDeviceKernelArgs
(
BaseArgument
*
p_arg
,
void
*
p_dev_kernel_args
)
const
override
{
return
SetDeviceKernelArgs
(
*
dynamic_cast
<
Argument
*>
(
p_arg
),
p_dev_kernel_args
);
}
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle_v2.hpp
View file @
66c70dfe
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023
-2024
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment