Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
613dcc6b
Commit
613dcc6b
authored
May 04, 2023
by
Po-Yen, Chen
Browse files
Remove elementwise-op objects from interfaces
parent
9fdc3fc8
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
12 additions
and
65 deletions
+12
-65
example/01_gemm/run_gemm_example.inc
example/01_gemm/run_gemm_example.inc
+1
-4
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp
...or_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp
+7
-45
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp
...nsor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp
+4
-16
No files found.
example/01_gemm/run_gemm_example.inc
View file @
613dcc6b
...
@@ -89,10 +89,7 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
...
@@ -89,10 +89,7 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
K
,
K
,
StrideA
,
StrideA
,
StrideB
,
StrideB
,
StrideC
,
StrideC
);
a_element_op
,
b_element_op
,
c_element_op
);
if
(
!
gemm
.
IsSupportedArgument
(
argument
))
if
(
!
gemm
.
IsSupportedArgument
(
argument
))
{
{
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp
View file @
613dcc6b
...
@@ -356,10 +356,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
...
@@ -356,10 +356,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
index_t
K
,
index_t
K
,
index_t
StrideA
,
index_t
StrideA
,
index_t
StrideB
,
index_t
StrideB
,
index_t
StrideC
,
index_t
StrideC
)
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
)
:
p_a_grid_
{
p_a_grid
},
:
p_a_grid_
{
p_a_grid
},
p_b_grid_
{
p_b_grid
},
p_b_grid_
{
p_b_grid
},
p_c_grid_
{
p_c_grid
},
p_c_grid_
{
p_c_grid
},
...
@@ -385,9 +382,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
...
@@ -385,9 +382,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
N
,
N
,
GridwiseGemm
::
CalculateNPadded
(
N
),
GridwiseGemm
::
CalculateNPadded
(
N
),
StrideC
)},
StrideC
)},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
c_element_op_
{
c_element_op
},
kraw_
{
K
}
kraw_
{
K
}
{
{
}
}
...
@@ -402,9 +396,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
...
@@ -402,9 +396,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
AElementwiseOperation
a_element_op_
;
BElementwiseOperation
b_element_op_
;
CElementwiseOperation
c_element_op_
;
index_t
kraw_
;
index_t
kraw_
;
};
};
...
@@ -451,9 +442,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
...
@@ -451,9 +442,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
kernel_gemm_xdl_cshuffle_v1
<
GridwiseGemm
,
kernel_gemm_xdl_cshuffle_v1
<
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
CDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
DeviceOp
::
CGridDesc_M_N
,
DeviceOp
::
CGridDesc_M_N
,
...
@@ -467,9 +455,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
...
@@ -467,9 +455,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
karg
.
p_a_grid_
,
karg
.
p_a_grid_
,
karg
.
p_b_grid_
,
karg
.
p_b_grid_
,
karg
.
p_c_grid_
,
karg
.
p_c_grid_
,
karg
.
a_element_op_
,
karg
.
b_element_op_
,
karg
.
c_element_op_
,
karg
.
a_grid_desc_ak0_m_ak1_
,
karg
.
a_grid_desc_ak0_m_ak1_
,
karg
.
b_grid_desc_bk0_n_bk1_
,
karg
.
b_grid_desc_bk0_n_bk1_
,
karg
.
c_grid_desc_m_n_
);
karg
.
c_grid_desc_m_n_
);
...
@@ -480,9 +465,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
...
@@ -480,9 +465,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
kernel_gemm_xdl_cshuffle_v1
<
GridwiseGemm
,
kernel_gemm_xdl_cshuffle_v1
<
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
CDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
DeviceOp
::
CGridDesc_M_N
,
DeviceOp
::
CGridDesc_M_N
,
...
@@ -495,9 +477,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
...
@@ -495,9 +477,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
karg
.
p_a_grid_
,
karg
.
p_a_grid_
,
karg
.
p_b_grid_
,
karg
.
p_b_grid_
,
karg
.
p_c_grid_
,
karg
.
p_c_grid_
,
karg
.
a_element_op_
,
karg
.
b_element_op_
,
karg
.
c_element_op_
,
karg
.
a_grid_desc_ak0_m_ak1_
,
karg
.
a_grid_desc_ak0_m_ak1_
,
karg
.
b_grid_desc_bk0_n_bk1_
,
karg
.
b_grid_desc_bk0_n_bk1_
,
karg
.
c_grid_desc_m_n_
);
karg
.
c_grid_desc_m_n_
);
...
@@ -554,23 +533,9 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
...
@@ -554,23 +533,9 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
index_t
K
,
index_t
K
,
index_t
StrideA
,
index_t
StrideA
,
index_t
StrideB
,
index_t
StrideB
,
index_t
StrideC
,
index_t
StrideC
)
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
)
{
{
return
Argument
{
p_a
,
return
Argument
{
p_a
,
p_b
,
p_c
,
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
};
p_b
,
p_c
,
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
,
a_element_op
,
b_element_op
,
c_element_op
};
}
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
...
@@ -585,9 +550,9 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
...
@@ -585,9 +550,9 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
index_t
StrideA
,
index_t
StrideA
,
index_t
StrideB
,
index_t
StrideB
,
index_t
StrideC
,
index_t
StrideC
,
AElementwiseOperation
a_element_op
,
AElementwiseOperation
,
BElementwiseOperation
b_element_op
,
BElementwiseOperation
,
CElementwiseOperation
c_element_op
)
override
CElementwiseOperation
)
override
{
{
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_a
),
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_a
),
static_cast
<
const
BDataType
*>
(
p_b
),
static_cast
<
const
BDataType
*>
(
p_b
),
...
@@ -597,10 +562,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
...
@@ -597,10 +562,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
K
,
K
,
StrideA
,
StrideA
,
StrideB
,
StrideB
,
StrideC
,
StrideC
);
a_element_op
,
b_element_op
,
c_element_op
);
}
}
// polymorphic
// polymorphic
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp
View file @
613dcc6b
...
@@ -20,9 +20,6 @@ namespace ck {
...
@@ -20,9 +20,6 @@ namespace ck {
template
<
typename
GridwiseGemm
,
template
<
typename
GridwiseGemm
,
typename
FloatAB
,
typename
FloatAB
,
typename
FloatC
,
typename
FloatC
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
AGridDesc_AK0_M_AK1
,
typename
AGridDesc_AK0_M_AK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
CGridDesc_M_N
,
typename
CGridDesc_M_N
,
...
@@ -34,9 +31,6 @@ __global__ void
...
@@ -34,9 +31,6 @@ __global__ void
kernel_gemm_xdl_cshuffle_v1
(
const
FloatAB
*
__restrict__
p_a_grid
,
kernel_gemm_xdl_cshuffle_v1
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
FloatC
*
__restrict__
p_c_grid
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
CElementwiseOperation
c_element_op
,
const
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1
,
const
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1
,
const
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1
,
const
CGridDesc_M_N
c_grid_desc_m_n
)
const
CGridDesc_M_N
c_grid_desc_m_n
)
...
@@ -48,9 +42,6 @@ __global__ void
...
@@ -48,9 +42,6 @@ __global__ void
p_b_grid
,
p_b_grid
,
p_c_grid
,
p_c_grid
,
p_shared
,
p_shared
,
a_element_op
,
b_element_op
,
c_element_op
,
a_grid_desc_ak0_m_ak1
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
b_grid_desc_bk0_n_bk1
,
c_grid_desc_m_n
);
c_grid_desc_m_n
);
...
@@ -58,9 +49,6 @@ __global__ void
...
@@ -58,9 +49,6 @@ __global__ void
ignore
=
p_a_grid
;
ignore
=
p_a_grid
;
ignore
=
p_b_grid
;
ignore
=
p_b_grid
;
ignore
=
p_c_grid
;
ignore
=
p_c_grid
;
ignore
=
a_element_op
;
ignore
=
b_element_op
;
ignore
=
c_element_op
;
ignore
=
a_grid_desc_ak0_m_ak1
;
ignore
=
a_grid_desc_ak0_m_ak1
;
ignore
=
b_grid_desc_bk0_n_bk1
;
ignore
=
b_grid_desc_bk0_n_bk1
;
ignore
=
c_grid_desc_m_n
;
ignore
=
c_grid_desc_m_n
;
...
@@ -339,9 +327,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
...
@@ -339,9 +327,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
const
FloatAB
*
__restrict__
p_b_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
FloatC
*
__restrict__
p_c_grid
,
void
*
__restrict__
p_shared
,
void
*
__restrict__
p_shared
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CElementwiseOperation
&
c_element_op
,
const
AGridDesc_AK0_M_AK1
&
a_grid_desc_ak0_m_ak1
,
const
AGridDesc_AK0_M_AK1
&
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
&
b_grid_desc_bk0_n_bk1
,
const
BGridDesc_BK0_N_BK1
&
b_grid_desc_bk0_n_bk1
,
const
CGridDesc_M_N
c_grid_desc_m_n
)
const
CGridDesc_M_N
c_grid_desc_m_n
)
...
@@ -356,8 +341,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
...
@@ -356,8 +341,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
auto
c_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
auto
c_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_c_grid
,
c_grid_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
());
p_c_grid
,
c_grid_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
());
// divide block work by [M, N]
const
AElementwiseOperation
a_element_op
{};
const
BElementwiseOperation
b_element_op
{};
const
CElementwiseOperation
c_element_op
{};
// divide block work by [M, N]
const
auto
block_2_ctile_map
=
MakeBlock2CTileMap
(
c_grid_desc_m_n
);
const
auto
block_2_ctile_map
=
MakeBlock2CTileMap
(
c_grid_desc_m_n
);
const
auto
block_work_idx
=
const
auto
block_work_idx
=
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment