Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
86699d01
Commit
86699d01
authored
Jul 21, 2023
by
Jing Zhang
Browse files
add instances for fp32 output
parent
0c3cfcf8
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
36 additions
and
14 deletions
+36
-14
client_example/20_grouped_gemm_bias/CMakeLists.txt
client_example/20_grouped_gemm_bias/CMakeLists.txt
+3
-0
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
...or_operation/gpu/element/unary_element_wise_operation.hpp
+7
-0
library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_bias.hpp
...brary/tensor_operation_instance/gpu/grouped_gemm_bias.hpp
+23
-12
library/src/tensor_operation_instance/gpu/grouped_gemm_bias/CMakeLists.txt
...r_operation_instance/gpu/grouped_gemm_bias/CMakeLists.txt
+3
-2
No files found.
client_example/20_grouped_gemm_bias/CMakeLists.txt
View file @
86699d01
# Build the fp16-output grouped-GEMM-with-bias client example.
# (The scraped diff duplicated these two directives — a second
# add_executable() for the same target is a CMake configure error,
# so each directive appears exactly once here.)
add_executable(client_grouped_gemm_fixed_nk_bias_fp16 grouped_gemm_fixed_nk_bias_fp16.cpp)
target_link_libraries(client_grouped_gemm_fixed_nk_bias_fp16 PRIVATE composable_kernel::device_operations)

# Build the fp32-output variant added by this commit.
add_executable(client_grouped_gemm_fixed_nk_bias_fp16_fp32_out grouped_gemm_fixed_nk_bias_fp16_fp32_out.cpp)
target_link_libraries(client_grouped_gemm_fixed_nk_bias_fp16_fp32_out PRIVATE composable_kernel::device_operations)
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
View file @
86699d01
...
@@ -139,6 +139,13 @@ struct AddBias
...
@@ -139,6 +139,13 @@ struct AddBias
{
{
e
=
c
+
d0
;
e
=
c
+
d0
;
}
}
template
<
>
__host__
__device__
void
operator
()
<
float
,
float
,
float
>
(
float
&
e
,
const
float
&
c
,
const
float
&
d0
)
const
{
e
=
c
+
d0
;
}
};
};
struct
UnaryConvert
struct
UnaryConvert
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_bias.hpp
View file @
86699d01
...
@@ -16,6 +16,7 @@ namespace tensor_operation {
...
@@ -16,6 +16,7 @@ namespace tensor_operation {
namespace
device
{
namespace
device
{
namespace
instance
{
namespace
instance
{
//fp16_output
void
add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instances
(
void
add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedGemmFixedNK
<
Row
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedGemmFixedNK
<
Row
,
Row
,
Row
,
...
@@ -42,33 +43,36 @@ void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instances(
...
@@ -42,33 +43,36 @@ void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instances(
PassThrough
,
PassThrough
,
AddBias
>>>&
instances
);
AddBias
>>>&
instances
);
#if 0
//fp32_output
void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f
16_km
_kn_mn_instances(
void
add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f
32_mk
_kn_mn_instances
(
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<
Col
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedGemmFixedNK
<
Row
,
Row
,
Row
,
Row_Tuple
,
Row_Tuple
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32_Tuple
,
F32_Tuple
,
F
16
,
F
32
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
AddBias
>>>&
instances
);
AddBias
>>>&
instances
);
void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f
16_km
_nk_mn_instances(
void
add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f
32_mk
_nk_mn_instances
(
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<
Col
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedGemmFixedNK
<
Row
,
Col
,
Col
,
Row_Tuple
,
Row_Tuple
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32_Tuple
,
F32_Tuple
,
F
16
,
F
32
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
AddBias
>>>&
instances
);
AddBias
>>>&
instances
);
#endif
template
<
typename
ALayout
,
template
<
typename
ALayout
,
typename
BLayout
,
typename
BLayout
,
...
@@ -105,6 +109,7 @@ struct DeviceOperationInstanceFactory<
...
@@ -105,6 +109,7 @@ struct DeviceOperationInstanceFactory<
{
{
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
//fp16_output
if
constexpr
(
is_same_v
<
ADataType
,
half_t
>
&&
is_same_v
<
BDataType
,
half_t
>
&&
if
constexpr
(
is_same_v
<
ADataType
,
half_t
>
&&
is_same_v
<
BDataType
,
half_t
>
&&
is_same_v
<
EDataType
,
half_t
>
)
is_same_v
<
EDataType
,
half_t
>
)
{
{
...
@@ -118,15 +123,21 @@ struct DeviceOperationInstanceFactory<
...
@@ -118,15 +123,21 @@ struct DeviceOperationInstanceFactory<
{
{
add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instances
(
op_ptrs
);
add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instances
(
op_ptrs
);
}
}
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
is_same_v
<
BLayout
,
Row
>
&&
}
//fp32_output
if
constexpr
(
is_same_v
<
ADataType
,
half_t
>
&&
is_same_v
<
BDataType
,
half_t
>
&&
is_same_v
<
EDataType
,
float
>
)
{
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
ELayout
,
Row
>
)
is_same_v
<
ELayout
,
Row
>
)
{
{
//
add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f
16_km
_kn_mn_instances(op_ptrs);
add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f
32_mk
_kn_mn_instances
(
op_ptrs
);
}
}
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
is_same_v
<
BLayout
,
Col
>
&&
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
ELayout
,
Row
>
)
is_same_v
<
ELayout
,
Row
>
)
{
{
//
add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f
16_km
_nk_mn_instances(op_ptrs);
add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f
32_mk
_nk_mn_instances
(
op_ptrs
);
}
}
}
}
return
op_ptrs
;
return
op_ptrs
;
...
...
library/src/tensor_operation_instance/gpu/grouped_gemm_bias/CMakeLists.txt
View file @
86699d01
# Instance library for grouped GEMM with bias.  The scraped diff showed
# the call header and the two f16 source lines twice (once per diff pane);
# each line appears exactly once here.  The km_* instances remain
# commented out as in the original; the two f32-output instance files are
# the additions from this commit.
add_instance_library(device_grouped_gemm_bias_instance
    device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instance.cpp
    device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instance.cpp
    #device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_km_kn_mn_instance.cpp
    #device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_km_nk_mn_instance.cpp
    device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_kn_mn_instance.cpp
    device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_nk_mn_instance.cpp
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment