Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
8c7d03ec
Commit
8c7d03ec
authored
Oct 18, 2023
by
Jing Zhang
Browse files
add setElementwiseOp
parent
2a964f40
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
39 additions
and
14 deletions
+39
-14
example/15_grouped_gemm/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp
...ed_gemm/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp
+3
-2
include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp
...or_operation/gpu/device/device_grouped_gemm_multi_abd.hpp
+3
-3
include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp
...ion/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp
+4
-0
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp
...evice/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp
+29
-9
No files found.
example/15_grouped_gemm/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp
View file @
8c7d03ec
...
@@ -243,8 +243,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
...
@@ -243,8 +243,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
std
::
vector
<
void
*>
p_Cs
=
{};
std
::
vector
<
void
*>
p_Cs
=
{};
// do GEMM
// do GEMM
auto
argument
=
gemm
.
MakeArgument
(
auto
argument
=
gemm
.
MakeArgument
(
p_As
,
p_Bs
,
p_Ds
,
p_Cs
,
gemm_descs
);
p_As
,
p_Bs
,
p_Ds
,
p_Cs
,
gemm_descs
,
a_element_op
,
b_element_op
,
cde_element_op
);
if
(
!
gemm
.
IsSupportedArgument
(
argument
))
if
(
!
gemm
.
IsSupportedArgument
(
argument
))
{
{
...
@@ -265,6 +264,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
...
@@ -265,6 +264,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
gemm
.
SetDeviceKernelArgs
(
argument
,
gemm_kernel_args_dev
.
GetDeviceBuffer
());
gemm
.
SetDeviceKernelArgs
(
argument
,
gemm_kernel_args_dev
.
GetDeviceBuffer
());
gemm
.
SetKBatch
(
argument
,
config
.
k_batch
);
gemm
.
SetKBatch
(
argument
,
config
.
k_batch
);
gemm
.
SetElementwiseOps
(
argument
,
a_element_op
,
b_element_op
,
cde_element_op
);
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
false
});
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
false
});
if
(
config
.
time_kernel
)
if
(
config
.
time_kernel
)
...
...
include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp
View file @
8c7d03ec
...
@@ -50,9 +50,9 @@ struct DeviceGroupedGemmMultiABD : public BaseOperator
...
@@ -50,9 +50,9 @@ struct DeviceGroupedGemmMultiABD : public BaseOperator
std
::
vector
<
std
::
array
<
const
void
*
,
NumDTensor
>>&
p_ds
,
std
::
vector
<
std
::
array
<
const
void
*
,
NumDTensor
>>&
p_ds
,
std
::
vector
<
void
*>&
p_e
,
std
::
vector
<
void
*>&
p_e
,
std
::
vector
<
GemmMultiABDDesc
>&
gemm_desc
,
std
::
vector
<
GemmMultiABDDesc
>&
gemm_desc
,
AElementwiseOperation
a_element_op
,
AElementwiseOperation
a_element_op
=
AElementwiseOperation
{}
,
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
=
BElementwiseOperation
{}
,
CElementwiseOperation
c_element_op
)
=
0
;
CElementwiseOperation
c_element_op
=
CElementwiseOperation
{}
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
};
...
...
include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp
View file @
8c7d03ec
...
@@ -56,6 +56,10 @@ struct DeviceGroupedGemmMultiABDFixedNK : DeviceGroupedGemmMultiABD<AsLayout,
...
@@ -56,6 +56,10 @@ struct DeviceGroupedGemmMultiABDFixedNK : DeviceGroupedGemmMultiABD<AsLayout,
virtual
void
SetDeviceKernelArgs
(
BaseArgument
*
p_arg
,
const
void
*
kernel_args
)
const
=
0
;
virtual
void
SetDeviceKernelArgs
(
BaseArgument
*
p_arg
,
const
void
*
kernel_args
)
const
=
0
;
virtual
size_t
GetDeviceKernelArgSize
(
const
BaseArgument
*
p_arg
)
const
=
0
;
virtual
size_t
GetDeviceKernelArgSize
(
const
BaseArgument
*
p_arg
)
const
=
0
;
virtual
void
SetKBatch
(
BaseArgument
*
p_arg
,
index_t
k_batch
)
const
=
0
;
virtual
void
SetKBatch
(
BaseArgument
*
p_arg
,
index_t
k_batch
)
const
=
0
;
// Interface hook for replacing the elementwise operators associated with an
// argument object after it has been created. Implementations receive the
// type-erased argument pointer plus the three operator objects by value.
virtual void SetElementwiseOps(BaseArgument* p_arg,
                               AElementwiseOperation a_element_op,
                               BElementwiseOperation b_element_op,
                               CElementwiseOperation cde_element_op) const = 0;
};
};
}
// namespace device
}
// namespace device
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp
View file @
8c7d03ec
...
@@ -453,9 +453,9 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK
...
@@ -453,9 +453,9 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK
std
::
vector
<
std
::
array
<
const
void
*
,
NumDTensor
>>&
,
std
::
vector
<
std
::
array
<
const
void
*
,
NumDTensor
>>&
,
std
::
vector
<
void
*>&
,
std
::
vector
<
void
*>&
,
std
::
vector
<
GemmMultiABDDesc
>&
gemm_descs
,
std
::
vector
<
GemmMultiABDDesc
>&
gemm_descs
,
AElementwiseOperation
a_element_op
,
AElementwiseOperation
a_element_op
=
AElementwiseOperation
{}
,
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
=
BElementwiseOperation
{}
,
CDEElementwiseOperation
c_element_op
)
CDEElementwiseOperation
c_element_op
=
CDEElementwiseOperation
{}
)
:
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
c_element_op_
{
c_element_op
}
:
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
c_element_op_
{
c_element_op
}
{
{
grid_size_
=
0
;
grid_size_
=
0
;
...
@@ -754,9 +754,9 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK
...
@@ -754,9 +754,9 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK
std
::
vector
<
std
::
array
<
const
void
*
,
NumDTensor
>>&
p_Ds
,
std
::
vector
<
std
::
array
<
const
void
*
,
NumDTensor
>>&
p_Ds
,
std
::
vector
<
void
*>&
p_Es
,
std
::
vector
<
void
*>&
p_Es
,
std
::
vector
<
GemmMultiABDDesc
>
gemm_descs
,
std
::
vector
<
GemmMultiABDDesc
>
gemm_descs
,
AElementwiseOperation
a_element_op
,
AElementwiseOperation
a_element_op
=
AElementwiseOperation
{}
,
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
=
BElementwiseOperation
{}
,
CDEElementwiseOperation
c_element_op
)
CDEElementwiseOperation
c_element_op
=
CDEElementwiseOperation
{}
)
{
{
return
Argument
{
return
Argument
{
p_As
,
p_Bs
,
p_Ds
,
p_Es
,
gemm_descs
,
a_element_op
,
b_element_op
,
c_element_op
};
p_As
,
p_Bs
,
p_Ds
,
p_Es
,
gemm_descs
,
a_element_op
,
b_element_op
,
c_element_op
};
...
@@ -771,9 +771,9 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK
...
@@ -771,9 +771,9 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK
std
::
vector
<
std
::
array
<
const
void
*
,
NumDTensor
>>&
p_Ds
,
std
::
vector
<
std
::
array
<
const
void
*
,
NumDTensor
>>&
p_Ds
,
std
::
vector
<
void
*>&
p_Es
,
std
::
vector
<
void
*>&
p_Es
,
std
::
vector
<
GemmMultiABDDesc
>&
gemm_descs
,
std
::
vector
<
GemmMultiABDDesc
>&
gemm_descs
,
AElementwiseOperation
a_element_op
,
AElementwiseOperation
a_element_op
=
AElementwiseOperation
{}
,
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
=
BElementwiseOperation
{}
,
CDEElementwiseOperation
c_element_op
)
override
CDEElementwiseOperation
c_element_op
=
CDEElementwiseOperation
{}
)
override
{
{
return
std
::
make_unique
<
Argument
>
(
return
std
::
make_unique
<
Argument
>
(
p_As
,
p_Bs
,
p_Ds
,
p_Es
,
gemm_descs
,
a_element_op
,
b_element_op
,
c_element_op
);
p_As
,
p_Bs
,
p_Ds
,
p_Es
,
gemm_descs
,
a_element_op
,
b_element_op
,
c_element_op
);
...
@@ -814,6 +814,16 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK
...
@@ -814,6 +814,16 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK
return
str
.
str
();
return
str
.
str
();
}
}
// Overwrites the A, B and CDE elementwise operators held by an existing
// Argument with the ones supplied by the caller. Each operator is taken by
// value and copy-assigned into the matching member of `arg`.
static void SetElementwiseOps(Argument& arg,
                              AElementwiseOperation a_element_op,
                              BElementwiseOperation b_element_op,
                              CDEElementwiseOperation c_element_op)
{
    arg.a_element_op_ = a_element_op;
    arg.b_element_op_ = b_element_op;
    arg.c_element_op_ = c_element_op;
}
static
void
SetDeviceKernelArgs
(
Argument
&
arg
,
const
void
*
kernel_args
)
static
void
SetDeviceKernelArgs
(
Argument
&
arg
,
const
void
*
kernel_args
)
{
{
arg
.
grouped_gemm_kernel_args_dev
=
kernel_args
;
arg
.
grouped_gemm_kernel_args_dev
=
kernel_args
;
...
@@ -825,6 +835,16 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK
...
@@ -825,6 +835,16 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK
return
SetDeviceKernelArgs
(
*
dynamic_cast
<
Argument
*>
(
p_arg
),
kernel_args
);
return
SetDeviceKernelArgs
(
*
dynamic_cast
<
Argument
*>
(
p_arg
),
kernel_args
);
}
}
// Type-erased entry point: downcasts the generic argument to this operation's
// Argument type and forwards to the static SetElementwiseOps overload.
// NOTE(review): the dynamic_cast result is dereferenced without a null check,
// so p_arg must actually point to an Argument of this operation.
void SetElementwiseOps(BaseArgument* p_arg,
                       AElementwiseOperation a_element_op,
                       BElementwiseOperation b_element_op,
                       CDEElementwiseOperation c_element_op) const override
{
    Argument& typed_arg = *dynamic_cast<Argument*>(p_arg);

    SetElementwiseOps(typed_arg, a_element_op, b_element_op, c_element_op);
}
size_t
GetWorkSpaceSize
(
const
BaseArgument
*
p_arg
)
const
override
size_t
GetWorkSpaceSize
(
const
BaseArgument
*
p_arg
)
const
override
{
{
auto
arg
=
*
dynamic_cast
<
const
Argument
*>
(
p_arg
);
auto
arg
=
*
dynamic_cast
<
const
Argument
*>
(
p_arg
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment