Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
489599ba
Commit
489599ba
authored
Apr 20, 2024
by
Jing Zhang
Committed by
root
Apr 20, 2024
Browse files
add multiD support into gridwise and deviceOp
parent
ad1597c4
Changes
5
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
421 additions
and
12 deletions
+421
-12
example/01_gemm/gemm_xdl_fp16_v3.cpp
example/01_gemm/gemm_xdl_fp16_v3.cpp
+1
-1
example/01_gemm/run_gemm_example_v2.inc
example/01_gemm/run_gemm_example_v2.inc
+3
-0
include/ck/tensor_operation/gpu/device/device_gemm_v2.hpp
include/ck/tensor_operation/gpu/device/device_gemm_v2.hpp
+5
-0
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp
...operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp
+36
-8
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
...nsor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
+376
-3
No files found.
example/01_gemm/gemm_xdl_fp16_v3.cpp
View file @
489599ba
...
@@ -25,7 +25,7 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
...
@@ -25,7 +25,7 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
using
DeviceGemmV2Instance
=
using
DeviceGemmV2Instance
=
ck
::
tensor_operation
::
device
::
DeviceGemm_Xdl_CShuffleV3
<
ck
::
tensor_operation
::
device
::
DeviceGemm_Xdl_CShuffleV3
<
ALayout
,
BLayout
,
CLayout
,
ALayout
,
BLayout
,
CLayout
,
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
CShuffleDataType
,
ADataType
,
BDataType
,
ck
::
Tuple
<>
,
CDataType
,
AccDataType
,
CShuffleDataType
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
256
,
224
,
256
,
224
,
256
,
...
...
example/01_gemm/run_gemm_example_v2.inc
View file @
489599ba
...
@@ -133,10 +133,12 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
...
@@ -133,10 +133,12 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
#ifdef BUILD_INT4_EXAMPLE
#ifdef BUILD_INT4_EXAMPLE
static_cast
<
KernelADataType
*>
(
a_m_k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
KernelADataType
*>
(
a_m_k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
KernelBDataType
*>
(
b_k_n_device_buf
.
GetDeviceBuffer
()),
static_cast
<
KernelBDataType
*>
(
b_k_n_device_buf
.
GetDeviceBuffer
()),
{},
static_cast
<
KernelCDataType
*>
(
c_m_n_device_buf
.
GetDeviceBuffer
()),
static_cast
<
KernelCDataType
*>
(
c_m_n_device_buf
.
GetDeviceBuffer
()),
#else
#else
static_cast
<
ADataType
*>
(
a_m_k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
ADataType
*>
(
a_m_k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
BDataType
*>
(
b_k_n_device_buf
.
GetDeviceBuffer
()),
static_cast
<
BDataType
*>
(
b_k_n_device_buf
.
GetDeviceBuffer
()),
{},
static_cast
<
CDataType
*>
(
c_m_n_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
c_m_n_device_buf
.
GetDeviceBuffer
()),
#endif
#endif
M
,
M
,
...
@@ -144,6 +146,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
...
@@ -144,6 +146,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
K
,
K
,
StrideA
,
StrideA
,
StrideB
,
StrideB
,
{},
StrideC
,
StrideC
,
KBatch
,
KBatch
,
a_element_op
,
a_element_op
,
...
...
include/ck/tensor_operation/gpu/device/device_gemm_v2.hpp
View file @
489599ba
...
@@ -14,21 +14,26 @@ template <typename ALayout,
...
@@ -14,21 +14,26 @@ template <typename ALayout,
typename
CLayout
,
typename
CLayout
,
typename
ADataType
,
typename
ADataType
,
typename
BDataType
,
typename
BDataType
,
typename
DsDataType
,
typename
CDataType
,
typename
CDataType
,
typename
AElementwiseOperation
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
typename
CElementwiseOperation
>
struct
DeviceGemmV2
:
public
BaseOperator
struct
DeviceGemmV2
:
public
BaseOperator
{
{
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
virtual
std
::
unique_ptr
<
BaseArgument
>
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
const
void
*
p_b
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds
,
void
*
p_c
,
void
*
p_c
,
ck
::
index_t
M
,
ck
::
index_t
M
,
ck
::
index_t
N
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
K
,
ck
::
index_t
StrideA
,
ck
::
index_t
StrideA
,
ck
::
index_t
StrideB
,
ck
::
index_t
StrideB
,
std
::
array
<
ck
::
index_t
,
NumDTensor
>
StrideDs
,
ck
::
index_t
StrideC
,
ck
::
index_t
StrideC
,
ck
::
index_t
KSplit
,
ck
::
index_t
KSplit
,
AElementwiseOperation
a_element_op
,
AElementwiseOperation
a_element_op
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp
View file @
489599ba
...
@@ -25,6 +25,7 @@ template <typename ALayout,
...
@@ -25,6 +25,7 @@ template <typename ALayout,
typename
CLayout
,
typename
CLayout
,
typename
ADataType
,
typename
ADataType
,
typename
BDataType
,
typename
BDataType
,
typename
DsDataType
,
typename
CDataType
,
typename
CDataType
,
typename
GemmAccDataType
,
typename
GemmAccDataType
,
typename
CShuffleDataType
,
typename
CShuffleDataType
,
...
@@ -69,11 +70,14 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
...
@@ -69,11 +70,14 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
CLayout
,
CLayout
,
ADataType
,
ADataType
,
BDataType
,
BDataType
,
DsDataType
,
CDataType
,
CDataType
,
AElementwiseOperation
,
AElementwiseOperation
,
BElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>
CElementwiseOperation
>
{
{
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
// GridwiseGemm
// GridwiseGemm
using
GridwiseGemm
=
GridwiseGemm_xdl_cshuffle_v3
<
using
GridwiseGemm
=
GridwiseGemm_xdl_cshuffle_v3
<
ALayout
,
ALayout
,
...
@@ -83,6 +87,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
...
@@ -83,6 +87,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
BDataType
,
BDataType
,
GemmAccDataType
,
GemmAccDataType
,
CShuffleDataType
,
CShuffleDataType
,
Tuple
<>
,
CDataType
,
CDataType
,
AElementwiseOperation
,
AElementwiseOperation
,
BElementwiseOperation
,
BElementwiseOperation
,
...
@@ -586,19 +591,35 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
...
@@ -586,19 +591,35 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
static
auto
MakeArgument
(
const
ADataType
*
p_a
,
static
auto
MakeArgument
(
const
ADataType
*
p_a
,
const
BDataType
*
p_b
,
const
BDataType
*
p_b
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds
,
CDataType
*
p_c
,
CDataType
*
p_c
,
index_t
M
,
index_t
M
,
index_t
N
,
index_t
N
,
index_t
K
,
index_t
K
,
index_t
StrideA
,
index_t
StrideA
,
index_t
StrideB
,
index_t
StrideB
,
std
::
array
<
index_t
,
NumDTensor
>
StrideDs
,
index_t
StrideC
,
index_t
StrideC
,
index_t
KBatch
,
index_t
KBatch
,
AElementwiseOperation
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
)
CElementwiseOperation
c_element_op
)
{
{
return
Argument
{
p_a
,
p_b
,
p_c
,
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
,
KBatch
};
return
Argument
{
p_a
,
p_b
,
p_ds
,
p_c
,
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideDs
,
StrideC
,
KBatch
,
a_element_op
,
b_element_op
,
c_element_op
};
}
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
...
@@ -606,28 +627,35 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
...
@@ -606,28 +627,35 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
// polymorphic
// polymorphic
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
const
void
*
p_b
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds
,
void
*
p_c
,
void
*
p_c
,
index_t
M
,
index_t
M
,
index_t
N
,
index_t
N
,
index_t
K
,
index_t
K
,
index_t
StrideA
,
index_t
StrideA
,
index_t
StrideB
,
index_t
StrideB
,
std
::
array
<
ck
::
index_t
,
NumDTensor
>
StrideDs
,
index_t
StrideC
,
index_t
StrideC
,
index_t
KBatch
,
index_t
KBatch
,
AElementwiseOperation
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
)
override
CElementwiseOperation
c_element_op
)
override
{
{
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_a
),
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_a
),
static_cast
<
const
BDataType
*>
(
p_b
),
static_cast
<
const
BDataType
*>
(
p_b
),
p_ds
,
static_cast
<
CDataType
*>
(
p_c
),
static_cast
<
CDataType
*>
(
p_c
),
M
,
M
,
N
,
N
,
K
,
K
,
StrideA
,
StrideA
,
StrideB
,
StrideB
,
StrideDs
,
StrideC
,
StrideC
,
KBatch
);
KBatch
,
a_element_op
,
b_element_op
,
c_element_op
);
}
}
// polymorphic
// polymorphic
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
View file @
489599ba
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment