gaoqiong / composable_kernel_ROCM — Commits

Commit 77fa9fda (unverified)
Authored Nov 27, 2024 by arai713; committed via GitHub on Nov 27, 2024

    Merge branch 'develop' into codegen_hiprtc

Parents: 760ea189, e7b62864
Changes: 72 files changed in total; 20 changed files shown below, with 898 additions and 325 deletions (+898 −325).
Changed files shown (20):

  include/ck/library/utility/host_tensor.hpp                                                                    +0    −0
  include/ck/library/utility/host_tensor_generator.hpp                                                          +0    −0
  include/ck/library/utility/iterator.hpp                                                                       +0    −0
  include/ck/library/utility/literals.hpp                                                                       +0    −0
  include/ck/library/utility/numeric.hpp                                                                        +0    −0
  include/ck/library/utility/ranges.hpp                                                                         +0    −0
  include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp                                                +131  −1
  include/ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp                                       +13   −37
  include/ck/tensor_operation/gpu/device/device_grouped_gemm_multiple_d_splitk.hpp                              +0    −136
  include/ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp                                         +18   −2
  include/ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp                                      +6    −86
  include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp  +58   −35
  include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp         +21   −3
  include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp                                       +19   −2
  include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp                              +55   −17
  include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp                       +31   −4
  include/ck/utility/loop_scheduler.hpp                                                                         +0    −1
  include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp                                              +224  −0
  library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp                                     +184  −1
  library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_instance.hpp  +138 −0
File moves (utility headers relocated out of library/):

  library/include/ck/library/utility/host_tensor.hpp            →  include/ck/library/utility/host_tensor.hpp
  library/include/ck/library/utility/host_tensor_generator.hpp  →  include/ck/library/utility/host_tensor_generator.hpp
  library/include/ck/library/utility/iterator.hpp               →  include/ck/library/utility/iterator.hpp
  library/include/ck/library/utility/literals.hpp               →  include/ck/library/utility/literals.hpp
  library/include/ck/library/utility/numeric.hpp                →  include/ck/library/utility/numeric.hpp
  library/include/ck/library/utility/ranges.hpp                 →  include/ck/library/utility/ranges.hpp
include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp   (+131 −1)

The copyright year is bumped from 2018-2023 to 2018-2024, and the header now additionally includes <array>, <sstream>, <stdexcept> and "ck/utility/ignore.hpp" next to the existing <iostream>, <vector> and "device_base.hpp".

A per-group kernel-argument structure is added inside namespace ck::tensor_operation::device:

    ///
    /// @brief Structure representing single GEMM problem arguments.
    ///
    /// The pointer to the vector of those structures is passed to the GroupedGEMM entry
    /// point kernel.
    ///
    /// @tparam NumDTensor The number of D input tensors.
    ///
    template <index_t NumDTensor = 0>
    struct GroupedGemmKernelArgument
    {
        __host__ __device__ GroupedGemmKernelArgument(const void* p_a_grid_,
                                                      const void* p_b_grid_,
                                                      std::array<const void*, NumDTensor> p_ds_grid_,
                                                      void* p_e_grid_,
                                                      index_t M_,
                                                      index_t N_,
                                                      index_t K_,
                                                      index_t StrideA_,
                                                      index_t StrideB_,
                                                      std::array<index_t, NumDTensor> StrideDs_,
                                                      index_t StrideE_)
            : p_a_grid{p_a_grid_},
              p_b_grid{p_b_grid_},
              p_ds_grid{p_ds_grid_},
              p_e_grid{p_e_grid_},
              M{M_},
              N{N_},
              K{K_},
              StrideA{StrideA_},
              StrideB{StrideB_},
              StrideDs{StrideDs_},
              StrideE{StrideE_}
        {
        }

        const void* p_a_grid;
        const void* p_b_grid;
        std::array<const void*, NumDTensor> p_ds_grid;
        void* p_e_grid;
        index_t M;
        index_t N;
        index_t K;
        index_t StrideA;
        index_t StrideB;
        std::array<index_t, NumDTensor> StrideDs;
        index_t StrideE;

        void Print() const
        {
            std::stringstream str;
            for(auto sd : StrideDs)
                str << sd << ",";

            std::cout << "arg {"
                      << "M:" << M << ", "
                      << "N:" << N << ", "
                      << "K:" << K << ", "
                      << "SA:" << StrideA << ", "
                      << "SB:" << StrideB << ", "
                      << "SE:" << StrideE << ", "
                      << "SDs: {" << str.str() << "}"
                      << "}" << std::endl;
        }
    };

The GemmDesc structure (ck::index_t M_, N_, K_; ...) is unchanged. In hunk @@ -48,6 +118,66 @@ struct DeviceGroupedGemm : public BaseOperator, after the existing pure virtuals MakeArgumentPointer(... CElementwiseOperation c_element_op) = 0; and MakeInvokerPointer() = 0;, three virtual member functions are added with default implementations that throw, so derived device operations only have to override the variants they actually support:

    //---------------------------------------------------------------------------------------------
    /// @brief Sets the device kernel arguments pointer and may copy data to device.
    ///
    /// TODO: Add which kernels are using this (TileLoop * FixedNK ??)
    ///
    /// @param      p_arg               The pointer to the Argument we're going to update.
    /// @param[in]  p_dev_kernel_args   The pointer to the device memory which will contain kernel
    ///                                 arguments.
    /// @param[in]  p_host_kernel_args  The pointer to the host memory which contains kernel
    ///                                 arguments that should be copied to device memory.
    ///
    virtual void SetDeviceKernelArgs(BaseArgument* p_arg,
                                     void* p_dev_kernel_args,
                                     const void* p_host_kernel_args) const
    {
        ignore = p_arg;
        ignore = p_dev_kernel_args;
        ignore = p_host_kernel_args;

        std::ostringstream err;
        err << "This function is not implemented by the kernel: " << this->GetTypeString()
            << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
        throw std::runtime_error(err.str());
    }

    //----------------------------------------------------------------------------------------------
    /// @brief Sets the device kernel arguments pointer and may copy data to device.
    ///
    /// @param      p_arg              The pointer to the Argument we're going to update.
    /// @param[in]  p_dev_kernel_args  The pointer to the device memory which contains kernel
    ///                                arguments.
    ///
    virtual void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const
    {
        ignore = p_arg;
        ignore = p_dev_kernel_args;

        std::ostringstream err;
        err << "This function is not implemented by the kernel: " << this->GetTypeString()
            << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
        throw std::runtime_error(err.str());
    }

    //----------------------------------------------------------------------------------------------
    /// @brief Gets the device kernel argument size.
    ///
    /// @param[in]  p_arg  The pointer to the Device op Argument.
    ///
    /// @return     The device kernel argument size.
    ///
    virtual size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const
    {
        ignore = p_arg;

        std::ostringstream err;
        err << "This function is not implemented by the kernel: " << this->GetTypeString()
            << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
        throw std::runtime_error(err.str());
    }
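A small usage sketch of the new per-group argument structure (not part of the commit; the pointer values and sizes below are made up): one entry is filled per GEMM in the group, and Print() dumps it for debugging.

    using ck::tensor_operation::device::GroupedGemmKernelArgument;

    // p_a, p_b, p_e are assumed to be valid device pointers for one group's A, B and E tensors.
    GroupedGemmKernelArgument<0> karg{p_a, p_b, /*p_ds_grid=*/{}, p_e,
                                      /*M=*/128, /*N=*/256, /*K=*/64,
                                      /*StrideA=*/64, /*StrideB=*/256,
                                      /*StrideDs=*/{}, /*StrideE=*/256};
    karg.Print(); // prints: arg {M:128, N:256, K:64, SA:64, SB:256, SE:256, SDs: {}}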
include/ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp   (+13 −37)

The copyright year is bumped to 2018-2024 and the include list is reworked; the header now pulls in "device_grouped_gemm_splitk.hpp". The local plain-aggregate definition of GroupedGemmKernelArgument<NumDTensor> (the same member list as the structure now defined in device_grouped_gemm.hpp, but without constructor or Print()) is removed.

In hunk @@ -41,21 +20,18 @@ the base of the interface changes:

    - struct DeviceGroupedGemmFixedNK : DeviceGroupedGemm<ALayout,
    + struct DeviceGroupedGemmFixedNK : DeviceGroupedGemmSplitK<ALayout,
                                                            BLayout,
                                                            DsLayout,
                                                            ELayout,
                                                            ADataType,
                                                            BDataType,
                                                            DsDataType,
                                                            EDataType,
                                                            AElementwiseOperation,
                                                            BElementwiseOperation,
                                                            CElementwiseOperation>

and the FixedNK-specific pure virtuals

    virtual void   SetDeviceKernelArgs(BaseArgument* p_arg, const void* kernel_args) const = 0;
    virtual size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const = 0;
    virtual void   SetKBatch(BaseArgument* p_arg, index_t k_batch) const = 0;

are dropped; equivalent entry points are now provided by the DeviceGroupedGemm and DeviceGroupedGemmSplitK base classes.
include/ck/tensor_operation/gpu/device/device_grouped_gemm_multiple_d_splitk.hpp   (deleted, 100644 → 0, −136)

The whole header is removed. It previously provided, in namespace ck::tensor_operation::device:

  * GroupedGemmMultipleDKernelArguments<NumDTensor = 0> — a per-group kernel-argument structure with the same members, constructor and Print() helper as the GroupedGemmKernelArgument shown above;
  * DeviceGroupedGemmMultipleDSplitK<ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation> — an interface deriving from DeviceGroupedGemm that declared three pure virtuals:

        /// @brief Sets the k batch size.
        virtual void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const = 0;

        /// @brief Sets the device kernel arguments pointer.
        virtual void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const = 0;

        /// @brief Gets the device kernel argument size.
        virtual size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const = 0;

These responsibilities now live in device_grouped_gemm.hpp and device_grouped_gemm_splitk.hpp.
include/ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp   (+18 −2)

The file keeps its MIT SPDX header, the 2018-2024 copyright and the "device_grouped_gemm.hpp" include; the standard-library includes (<iostream> / <vector>) are adjusted. In hunk @@ -31,7 +31,23 @@ struct DeviceGroupedGemmSplitK : public DeviceGroupedGemm<ALayout, ... CElementwiseOperation>, the existing pure virtual SetKBatchSize gains a doxygen block and a forwarding convenience virtual SetKBatch is added:

    //----------------------------------------------------------------------------------------------
    /// @brief Sets the k batch size.
    ///
    /// @param      p_arg   Pointer to the Argument we're going to change.
    /// @param[in]  kbatch  The kbatch value.
    ///
    virtual void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const = 0;

    //----------------------------------------------------------------------------------------------
    /// @brief Sets the k batch size.
    ///
    /// @param      p_arg   Pointer to the Argument we're going to change.
    /// @param[in]  kbatch  The kbatch value.
    ///
    virtual void SetKBatch(BaseArgument* p_arg, index_t kbatch) const
    {
        this->SetKBatchSize(p_arg, kbatch);
    };
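A short sketch of how the forwarding virtual is meant to be used (the operation and argument objects below are assumed to come from the usual MakeArgumentPointer flow; the value 4 is arbitrary):

    // Implementations only need to override SetKBatchSize(); callers holding the op through the
    // split-K interface may use either spelling.
    grouped_gemm_ptr->SetKBatchSize(argument_ptr.get(), 4); // pure-virtual entry point
    grouped_gemm_ptr->SetKBatch(argument_ptr.get(), 4);     // convenience, forwards to SetKBatchSize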
include/ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp   (+6 −86)

Hunk @@ -3,83 +3,20 @@: the extra includes (<array>, <iostream>, <vector>, <sstream>) are dropped, keeping only "device_grouped_gemm.hpp". The local GroupedGemmTileLoopKernelArguments<NumDTensor> structure (again the same members, constructor and Print() helper as GroupedGemmKernelArgument) is removed together with its documentation block, and a kernel-level description is added in its place:

    /// @brief Grouped GEMM kernel using output Tile Looping algorithm
    ///
    /// @par This kernel does not require any knowledge about input data sizes (GEMM M/N/K)
    ///      It requires only the number of groups to launch. Other information like
    ///      data pointers and GEMM sizes, packed into gemm kernel args may be all dynamic
    ///      (known only at kernel run-time).
    ///
    /// @note This kernel does not support SplitK.

Hunk @@ -104,23 +41,6 @@ struct DeviceGroupedGemmTileLoop : public DeviceGroupedGemm<ALayout, ... CDEElementwiseOperation>: the pure virtuals SetDeviceKernelArgs(BaseArgument*, void*) and GetDeviceKernelArgSize(const BaseArgument*) are removed from this interface; the overridable defaults now come from DeviceGroupedGemm.
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp   (+58 −35)

@@ -18,7 +18,6 @@  The include of "ck/tensor_operation/gpu/device/device_grouped_gemm_multiple_d_splitk.hpp" (the header deleted above) is removed; tensor_descriptor, tensor_layout, gridwise_elementwise_2d, device_grouped_gemm_xdl_splitk_cshuffle and gemm_specialization includes stay.

@@ -78,17 +77,17 @@  The device op now derives from the generic split-K interface:

    - struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
    -     : public DeviceGroupedGemmMultipleDSplitK<ALayout, BLayout, DsLayout, ELayout, ...>
    + struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
    +     : public DeviceGroupedGemmSplitK<ALayout, BLayout, DsLayout, ELayout,
    +                                      ADataType, BDataType, DsDataType, EDataType,
    +                                      AElementwiseOperation, BElementwiseOperation,
    +                                      CDEElementwiseOperation>

@@ -530, -566, -621, -723, -746, -755  The Argument member holding the device copy of the per-group kernel arguments is renamed and loses its const qualification,

    - const void* p_dev_gemm_args_;   // Pointer to device memory with GEMM kernel arguments.
    + void*       p_dev_gemm_kargs_;

and the corresponding Run / DispatchKernel / LaunchKernel parameters change from `const void* dev_gemm_args` to `void* dev_gemm_kargs` (the nullptr check in Run and the "The gemm arguments device buffer is not allocated!" error path are updated to the new name). The host-to-device copy of the packed kernel arguments moves into LaunchKernel, where it runs on the launch stream before the preprocess lambda that zero-fills the workspace:

    hip_check_error(hipMemcpyWithStream(dev_gemm_kargs,
                                        arg.gemm_kernel_args_.data(),
                                        arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg),
                                        hipMemcpyHostToDevice,
                                        stream_config.stream_id_));

The kernel launch then passes cast_pointer_to_constant_address_space(dev_gemm_kargs) together with arg.gemm_kernel_args_.size(), a_element_op_ and b_element_op_, as before.

@@ -930,18 +936,30 @@  The old pair of SetDeviceKernelArgs overloads (an Argument&-taking helper that performed the hipMemcpy plus a thin polymorphic forwarder) is replaced by a single override that only records the device pointer, and a guarded GetDeviceKernelArgSize override is added next to it:

    void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override
    {
        auto arg_ptr = dynamic_cast<Argument*>(p_arg);
        if(arg_ptr)
        {
            arg_ptr->p_dev_gemm_kargs_ = p_dev_kernel_args;
        }
        else
            throw std::runtime_error(
                "The argument pointer is not an object of "
                "DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage::Argument structure!");
    }

    size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override
    {
        auto arg = dynamic_cast<const Argument*>(p_arg);
        if(arg)
        {
            return arg->gemm_kernel_args_.size() * sizeof(GemmTransKernelArg);
        }
        else
            throw std::runtime_error(
                "The argument pointer is not an object of "
                "DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage::Argument structure!");
    }

GetWorkSpaceSize keeps its existing override.

@@ -974,17 +992,22 @@  The static SetKBatchSize(Argument&, index_t) helper is marked [[deprecated]], and the polymorphic SetKBatchSize(BaseArgument*, index_t) override is rewritten with the same dynamic_cast guard (calling p_arg_->UpdateKBatch(kbatch) or throwing). The previous unguarded GetDeviceKernelArgSize implementation at this spot is removed in favour of the guarded version shown above.
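With the copy moved into LaunchKernel, the client-side sequence becomes roughly the following — a sketch under assumed object names (dev_op, argument_ptr, invoker_ptr), not code taken from the repository:

    // dev_op, argument_ptr and invoker_ptr stand in for the usual device-op / argument / invoker
    // objects of DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage.
    void* dev_kargs = nullptr;
    hip_check_error(hipMalloc(&dev_kargs, dev_op.GetDeviceKernelArgSize(argument_ptr.get())));

    dev_op.SetDeviceKernelArgs(argument_ptr.get(), dev_kargs); // only records the pointer now
    invoker_ptr->Run(argument_ptr.get(), StreamConfig{});      // copies gemm_kernel_args_, then launches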
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp   (+21 −3)

@@ -20,7 +20,6 @@  The include of "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" is dropped; gemm_specialization.hpp, block_to_ctile_map.hpp, gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp (still carrying its "// stare wywalic" note, Polish for "old, to be removed") and gridwise_gemm_pipeline_selector.hpp remain.

@@ -522,7 +521,7 @@  The kernel-argument alias switches to the shared structure:

    - using KernelArguments = GroupedGemmTileLoopKernelArguments<NumDTensor>;
    + using KernelArguments = GroupedGemmKernelArgument<NumDTensor>;

(Block2ETileMap / OffsettedLocalBlock2ETileMap aliases are unchanged.)

@@ -936,12 +935,31 @@  A pair of three-parameter SetDeviceKernelArgs overloads is added; they record the device pointer in the Argument and copy the host-side kernel arguments to it:

    void SetDeviceKernelArgs(Argument& arg,
                             void* p_dev_kernel_args,
                             const void* p_host_kernel_args) const
    {
        arg.p_dev_gemm_args_ = p_dev_kernel_args;
        hip_check_error(hipMemcpy(p_dev_kernel_args,
                                  p_host_kernel_args,
                                  GetDeviceKernelArgSize(&arg),
                                  hipMemcpyHostToDevice));
    }

    virtual void SetDeviceKernelArgs(BaseArgument* p_arg,
                                     void* p_dev_kernel_args,
                                     const void* p_host_kernel_args) const override
    {
        return SetDeviceKernelArgs(
            *dynamic_cast<Argument*>(p_arg), p_dev_kernel_args, p_host_kernel_args);
    }

The existing two-parameter overloads, which only store the device pointer in arg.p_dev_gemm_args_, are kept unchanged (the polymorphic one now carries an explicit `virtual`).
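For the tile-loop kernel, the three-parameter overload lets the caller keep the per-group arguments on the host and have the device op perform the copy. A sketch under assumed names (gemm_descs, op and arg_ptr are illustrative, not from the repository):

    using ck::tensor_operation::device::GroupedGemmKernelArgument;

    std::vector<GroupedGemmKernelArgument<0>> host_kargs;
    for(const auto& g : gemm_descs) // one entry per GEMM in the group
        host_kargs.push_back({g.p_a, g.p_b, {}, g.p_e, g.M, g.N, g.K,
                              g.StrideA, g.StrideB, {}, g.StrideE});

    void* dev_kargs = nullptr;
    hip_check_error(hipMalloc(&dev_kargs, op.GetDeviceKernelArgSize(arg_ptr.get())));
    op.SetDeviceKernelArgs(arg_ptr.get(), dev_kargs, host_kargs.data()); // copies host → device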
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp   (+19 −2)

The copyright year is bumped to 2018-2024. In hunk @@ -717,7 +717,24 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout, ..., GetWorkSpaceSize gains a dynamic_cast guard, and the new device-kernel-argument entry points are mapped onto the existing workspace handling:

    size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
    {
        auto p_arg_ = dynamic_cast<const Argument*>(p_arg);
        if(p_arg_)
        {
            return p_arg_->group_count_ * sizeof(GemmBiasTransKernelArg);
        }
        else
            throw std::runtime_error("The argument pointer is not an object of "
                                     "DeviceGroupedGemmMultipleDXdlCShuffle::Argument structure!");
    }

    size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override
    {
        return GetWorkSpaceSize(p_arg);
    }

    void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override
    {
        return this->SetWorkSpacePointer(p_arg, p_dev_kernel_args);
    }
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp   (+55 −17)

@@ -445,6 +445,7 @@  A `// TODO: replace with GroupedGemmKernelArgument` note is added above the local GemmBiasTransKernelArg structure (the Block2ETileMap / GroupedGemmBlock2ETileMap aliases around it are unchanged).

@@ -900,40 +901,58 @@  The static SetDeviceKernelArgs(Argument&, const void*) helper is removed, and the polymorphic override changes its second parameter from `const void*` to `void*`. It and the other accessors are rewritten with dynamic_cast null checks that throw instead of dereferencing a possibly-null cast result:

    // polymorphic
    void SetDeviceKernelArgs(BaseArgument* p_arg, void* kernel_args) const override
    {
        auto arg_ptr = dynamic_cast<Argument*>(p_arg);
        if(arg_ptr)
        {
            arg_ptr->grouped_gemm_kernel_args_dev = kernel_args;
        }
        else
            throw std::runtime_error("The argument pointer is not an object of "
                                     "DeviceGroupedGemm_Xdl_Fixed_NK::Argument structure!");
    }

GetWorkSpaceSize now returns arg_ptr->group_count_ * arg_ptr->barrier_size_grp_ * sizeof(uint32_t) through the same guard; GetDeviceKernelArgSize returns arg_ptr->group_count_ * sizeof(GroupedGemmKernelArgument<NumDTensor>); SetWorkSpacePointer stores p_workspace through the guard and then zero-fills it with hipMemsetAsync(p_workspace, 0, GetWorkSpaceSize(arg_ptr), stream_config.stream_id_). The static SetKBatch(Argument&, index_t) helper calling arg.UpdateKBatch(k_batch) is unchanged.

@@ -941,7 +960,26 @@  The polymorphic SetKBatch override gets the same guard, and a SetKBatchSize override (required by the DeviceGroupedGemmSplitK base that the FixedNK interface now ultimately derives from) is added:

    void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const override
    {
        auto arg_ptr = dynamic_cast<Argument*>(p_arg);
        if(arg_ptr)
        {
            arg_ptr->UpdateKBatch(kbatch);
        }
        else
            throw std::runtime_error("The argument pointer is not an object of "
                                     "DeviceGroupedGemm_Xdl_Fixed_NK::Argument structure!");
    }
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp   (+31 −4)

@@ -546,7 +546,8 @@  Hunk touching the per-group argument validation loop (`bool supported = true; for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i) { const auto& a = arg.gemm_kernel_args_[i].karg_; bool group_arg_valid = GridwiseGemm::CheckValidity(a); if(not group_arg_valid) ... }`).

@@ -636,16 +637,42 @@  GetWorkSpaceSize is rewritten with a dynamic_cast guard, GetDeviceKernelArgSize and SetDeviceKernelArgs overrides are added, and the polymorphic SetKBatchSize gets the same guard:

    size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
    {
        auto p_arg_ = dynamic_cast<const Argument*>(p_arg);
        if(p_arg_)
        {
            return p_arg_->gemm_kernel_args_.size() * sizeof(GemmTransKernelArg);
        }
        else
            throw std::runtime_error(
                "The argument pointer is not an object of "
                "DeviceGroupedGemmMultipleDSplitKXdlCShuffle::Argument structure!");
    }

    size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override
    {
        return GetWorkSpaceSize(p_arg);
    }

    // TODO: deperecation notice.
    static void SetKBatchSize(Argument& arg, index_t kbatch) { arg.UpdateKBatch(kbatch); }

    // polymorphic
    void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const override
    {
        auto p_arg_ = dynamic_cast<Argument*>(p_arg);
        if(p_arg_)
        {
            p_arg_->UpdateKBatch(kbatch);
        }
        else
            throw std::runtime_error(
                "The argument pointer is not an object of "
                "DeviceGroupedGemmMultipleDSplitKXdlCShuffle::Argument structure!");
    }

    void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override
    {
        return this->SetWorkSpacePointer(p_arg, p_dev_kernel_args);
    }
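The same guard pattern (dynamic_cast, null check, throw) now appears in most overrides across these implementation headers. Purely as an illustration of that design choice, and not a helper that exists in the repository, it could be expressed once as:

    #include <stdexcept>
    #include <string>

    // Hypothetical utility: cast a BaseArgument-style pointer to the concrete Argument type of a
    // device op, or throw with the same message the overrides use today.
    template <typename DerivedArg, typename BaseArg>
    DerivedArg* checked_arg_cast(BaseArg* p_arg, const std::string& op_name)
    {
        auto* p = dynamic_cast<DerivedArg*>(p_arg);
        if(!p)
            throw std::runtime_error("The argument pointer is not an object of " + op_name +
                                     "::Argument structure!");
        return p;
    }

    // e.g. inside an override:
    //   checked_arg_cast<Argument>(p_arg, "DeviceGroupedGemm_Xdl_Fixed_NK")->UpdateKBatch(kbatch);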
include/ck/utility/loop_scheduler.hpp   (+0 −1)

Hunk @@ -8,7 +8,6 @@: the include of "ck/tensor_description/tensor_adaptor.hpp" is removed; "ck/utility/common_header.hpp" stays.
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp   (+224 −0)

@@ -322,6 +322,7 @@  In the existing PipelineImpl hot loop of GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>, one line is added: a block_sync_lds() between the block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window) call and the following LocalPrefill of the next A block tile.

@@ -374,6 +375,229 @@  A new pipeline implementation is added for the interwave scheduler; unlike the existing variant it issues no second block_sync_lds() after the block GEMM:

    template <>
    struct PipelineImpl<GemmPipelineScheduler::Interwave>
    {
        template <typename DstBlockTile, typename SrcTileWindow>
        CK_TILE_DEVICE void GlobalPrefetch(DstBlockTile& dst_block_tile,
                                           SrcTileWindow& dram_tile_window) const
        {
            load_tile(dst_block_tile, dram_tile_window);
            move_tile_window(dram_tile_window, {0, KPerBlock});
        }

        template <typename DstTileWindow, typename SrcBlockTile, typename ElementFunction>
        CK_TILE_DEVICE void LocalPrefill(DstTileWindow& lds_tile_window,
                                         const SrcBlockTile& src_block_tile,
                                         const ElementFunction& element_func) const
        {
            const auto block_tile_tmp = tile_elementwise_in(element_func, src_block_tile);
            store_tile(lds_tile_window, block_tile_tmp);
        }

        template <bool HasHotLoop,
                  TailNumber TailNum,
                  typename ADramBlockWindowTmp,
                  typename BDramBlockWindowTmp,
                  typename AElementFunction,
                  typename BElementFunction>
        CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
                                       const AElementFunction& a_element_func,
                                       const BDramBlockWindowTmp& b_dram_block_window_tmp,
                                       const BElementFunction& b_element_func,
                                       index_t num_loop,
                                       void* p_smem) const
        {
            static_assert(
                std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
                    std::is_same_v<BDataType,
                                   remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
                "A/B Dram block window should have the same data type as appropriate "
                "([A|B]DataType) defined in Problem definition!");

            static_assert(MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                              NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                              KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
                          "A/B block window appropriate sizes must be equal to MPerBlock/NPerblock"
                          " or KPerBlock!");

            // ------------------------------------------------------------------------------------
            // Definitions of all needed tiles

            // A tile in LDS
            ADataType* p_a_lds = static_cast<ADataType*>(p_smem);

            constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor<Problem>();
            auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);

            // TODO: LDS alignment should come from Policy!
            constexpr index_t a_lds_block_space_size_aligned =
                integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(),
                                    16) * 16;

            // B tile in LDS
            BDataType* p_b_lds = static_cast<BDataType*>(static_cast<void*>(
                static_cast<char*>(p_smem) + a_lds_block_space_size_aligned));

            constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
            auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);

            // A DRAM tile window for load
            auto a_copy_dram_window =
                make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
                                 make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
                                 a_dram_block_window_tmp.get_window_origin(),
                                 Policy::template MakeADramTileDistribution<Problem>());

            // A LDS tile window for store
            auto a_copy_lds_window =
                make_tile_window(a_lds_block,
                                 make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
                                 {0, 0},
                                 a_copy_dram_window.get_tile_distribution());

            // B DRAM tile window for load
            auto b_copy_dram_window =
                make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(),
                                 make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
                                 b_dram_block_window_tmp.get_window_origin(),
                                 Policy::template MakeBDramTileDistribution<Problem>());

            // B LDS tile window for store
            auto b_copy_lds_window =
                make_tile_window(b_lds_block,
                                 make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
                                 {0, 0},
                                 b_copy_dram_window.get_tile_distribution());

            // A LDS tile for block GEMM
            auto a_lds_gemm_window = make_tile_window(
                a_lds_block, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {0, 0});

            // B LDS tile for block GEMM
            auto b_lds_gemm_window = make_tile_window(
                b_lds_block, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0});

            // Block GEMM
            auto block_gemm   = BlockGemm();
            auto c_block_tile = block_gemm.MakeCBlockTile();

            using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
            using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());

            using ABlockTile =
                decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
            using BBlockTile =
                decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));

            tuple_array<ABlockTile, PrefetchStages> a_block_tiles;
            tuple_array<BBlockTile, PrefetchStages> b_block_tiles;

            // -----------------------------------------------------------------------------------------
            // Gemm pipeline start

            // prefetch
            // global read 0
            GlobalPrefetch(a_block_tiles.get(I0{}), a_copy_dram_window);
            GlobalPrefetch(b_block_tiles.get(I0{}), b_copy_dram_window);

            // initialize C
            tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);

            // LDS write 0
            LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func);
            LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func);

            // Global prefetch [1, PrefetchStages]
            static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) {
                GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}), a_copy_dram_window);
                GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}), b_copy_dram_window);
            });

            // main body
            if constexpr(HasHotLoop)
            {
                index_t i = 0;
                do
                {
                    static_for<0, PrefetchStages, 1>{}([&](auto prefetch_idx) {
                        block_sync_lds();
                        block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
                        // no second block_sync_lds because it's interwave

                        LocalPrefill(
                            a_copy_lds_window,
                            a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
                            a_element_func);
                        LocalPrefill(
                            b_copy_lds_window,
                            b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
                            b_element_func);

                        GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}),
                                       a_copy_dram_window);
                        GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}),
                                       b_copy_dram_window);
                    });

                    i += PrefetchStages;
                } while(i < (num_loop - PrefetchStages));
            }

            auto HotLoopTail = [&](auto tail_num) {
                static_for<1, tail_num, 1>{}([&](auto prefetch_idx) {
                    block_sync_lds();
                    block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
                    // no second block_sync_lds because it's interwave

                    LocalPrefill(a_copy_lds_window,
                                 a_block_tiles.get(number<prefetch_idx>{}),
                                 a_element_func);
                    LocalPrefill(b_copy_lds_window,
                                 b_block_tiles.get(number<prefetch_idx>{}),
                                 b_element_func);
                });

                block_sync_lds();
                block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
            };

            if constexpr(TailNum == TailNumber::One)
            {
                block_sync_lds();
                block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
            }
            else if constexpr(TailNum == TailNumber::Two)   { HotLoopTail(number<2>{}); }
            else if constexpr(TailNum == TailNumber::Three) { HotLoopTail(number<3>{}); }
            else if constexpr(TailNum == TailNumber::Four)  { HotLoopTail(number<4>{}); }
            else if constexpr(TailNum == TailNumber::Five)  { HotLoopTail(number<5>{}); }
            else if constexpr(TailNum == TailNumber::Six)   { HotLoopTail(number<6>{}); }
            else if constexpr(TailNum == TailNumber::Seven) { HotLoopTail(number<7>{}); }
            else if constexpr(TailNum == TailNumber::Full)  { HotLoopTail(number<PrefetchStages>{}); }

            return c_block_tile;
        }
    };

The member template that follows (the operator() overload taking ADramBlockWindowTmp, BDramBlockWindowTmp, AElementFunction, ...) is unchanged context.
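The hot loop rotates through PrefetchStages register-resident block tiles: in each step the data currently staged in LDS feeds block_gemm, the next register tile is prefilled into LDS, and the just-consumed register slot is refilled from global memory. A tiny host-side model of that schedule (plain C++, illustrative only; it is not CK code):

    #include <cstdio>

    // Model of the interwave pipeline's ring-buffer schedule: with PrefetchStages register
    // buffers, iteration i computes from LDS, prefills LDS from buffer (i + 1) % PrefetchStages
    // and refills buffer i % PrefetchStages from global memory for a later iteration.
    int main()
    {
        constexpr int PrefetchStages = 2;
        constexpr int num_loop       = 6;

        for(int i = 0; i < num_loop; ++i)
        {
            const int consume = i % PrefetchStages;
            const int prefill = (i + 1) % PrefetchStages;
            std::printf("iter %d: gemm on LDS, LDS <- reg[%d], reg[%d] <- global\n",
                        i, prefill, consume);
        }
    }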
library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp   (+184 −1)

@@ -95,6 +95,45 @@  Three FP16 row-row split-K instance lists are declared, each following the pattern

    void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_pv1_inter_instances(
        std::vector<std::unique_ptr<DeviceGroupedGemm<Row, Row, Empty_Tuple, Row,
                                                      F16, F16, Empty_Tuple, F16,
                                                      PassThrough, PassThrough, PassThrough>>>& instances);

namely ..._mk_kn_mn_irregular_pv1_inter_instances, ..._mk_kn_mn_irregular_pv1_instances and ..._mk_kn_mn_irregular_pv2_instances, placed ahead of the existing add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances declaration (Row, Col layouts).

@@ -189,6 +228,124 @@  Nine analogous BF16 declarations are added inside the CK_ENABLE_BF16 block, covering three layout combinations with pv1_inter / pv1 / pv2 variants each:

    add_device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_kn_mn_irregular_pv1_inter_instances   (Row, Row, Empty_Tuple, Row)
    add_device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_kn_mn_irregular_pv1_instances
    add_device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_kn_mn_irregular_pv2_instances
    add_device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_nk_mn_irregular_pv1_inter_instances   (Row, Col, Empty_Tuple, Row)
    add_device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_nk_mn_irregular_pv1_instances
    add_device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_nk_mn_irregular_pv2_instances
    add_device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_km_kn_mn_irregular_pv1_inter_instances   (Col, Row, Empty_Tuple, Row)
    add_device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_km_kn_mn_irregular_pv1_instances
    add_device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_km_kn_mn_irregular_pv2_instances

each taking std::vector<std::unique_ptr<DeviceGroupedGemm<..., BF16, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough>>>& instances. The #endif and the following "#if defined(CK_ENABLE_BF16) && defined(CK_ENABLE_INT8)" block are unchanged.

@@ -262,7 +419,11 @@  In DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedGemm<...>>, the FP16 Row/Row branch replaces the single add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances(op_ptrs) call with the three new pv1_inter / pv1 / pv2 calls; the existing xdl, xdl_splitk and multiple_d_xdl_two_stage registrations stay.

@@ -334,12 +495,34 @@  For BF16, the three mk_kn calls are added to the existing Row/Row/Row branch, the three mk_nk calls to the Row/Col/Row branch, and a new branch

    else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
                      is_same_v<ELayout, Row>)

is introduced that registers the three km_kn variants.
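A sketch of how these new registrations reach clients through the factory named in the hunk headers (the GetInstances() entry point and the exact template arguments below are assumed from the library's usual usage pattern, not shown in this diff):

    using ck::tensor_operation::device::DeviceGroupedGemm;
    using ck::tensor_operation::device::instance::DeviceOperationInstanceFactory;

    using DeviceOp = DeviceGroupedGemm<Row, Row, Empty_Tuple, Row,
                                       F16, F16, Empty_Tuple, F16,
                                       PassThrough, PassThrough, PassThrough>;

    // The returned op_ptrs now also contain the *_irregular_pv1_inter / pv1 / pv2 split-K kernels.
    const auto op_ptrs = DeviceOperationInstanceFactory<DeviceOp>::GetInstances();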
File renamed:

  library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instance.cpp
    →  library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_instance.hpp   (+138 −0)

The instance definitions move from a .cpp translation unit into a header; the diff body is collapsed in this view and not reproduced here.