Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
a09f0171
Commit
a09f0171
authored
Oct 26, 2023
by
Jing Zhang
Browse files
add f16_f8_f16
parent
c6033af3
Changes
4
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
260 additions
and
2 deletions
+260
-2
library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp
...include/ck/library/tensor_operation_instance/gpu/gemm.hpp
+24
-0
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp
...vice_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp
+5
-2
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f8_f16_mk_kn_mn_instance.cpp
...evice_gemm_xdl_c_shuffle_f16_f8_f16_mk_kn_mn_instance.cpp
+121
-0
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f8_f16_mk_nk_mn_instance.cpp
...evice_gemm_xdl_c_shuffle_f16_f8_f16_mk_nk_mn_instance.cpp
+110
-0
No files found.
library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp
View file @
a09f0171
...
@@ -328,6 +328,16 @@ void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_instances(
...
@@ -328,6 +328,16 @@ void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_instances(
void
add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_nk_mn_instances
(
void
add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_nk_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Row
,
Col
,
Row
,
F8
,
F8
,
F8
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
DeviceGemm
<
Row
,
Col
,
Row
,
F8
,
F8
,
F8
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_c_shuffle_f16_f8_f16_mk_kn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Row
,
Row
,
Row
,
F16
,
F8
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_c_shuffle_f16_f8_f16_mk_nk_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Row
,
Col
,
Row
,
F16
,
F8
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
#endif
#endif
template
<
typename
ALayout
,
template
<
typename
ALayout
,
...
@@ -549,6 +559,20 @@ struct DeviceOperationInstanceFactory<
...
@@ -549,6 +559,20 @@ struct DeviceOperationInstanceFactory<
add_device_gemm_xdl_c_shuffle_f8_f8_f8_km_nk_mn_instances
(
op_ptrs
);
add_device_gemm_xdl_c_shuffle_f8_f8_f8_km_nk_mn_instances
(
op_ptrs
);
}
}
}
}
else
if
constexpr
(
is_same_v
<
ADataType
,
ck
::
half_t
>
&&
is_same_v
<
BDataType
,
ck
::
f8_t
>
&&
is_same_v
<
CDataType
,
ck
::
half_t
>
)
{
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_gemm_xdl_c_shuffle_f16_f8_f16_mk_kn_mn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_gemm_xdl_c_shuffle_f16_f8_f16_mk_nk_mn_instances
(
op_ptrs
);
}
}
#endif
#endif
return
op_ptrs
;
return
op_ptrs
;
}
}
...
...
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp
View file @
a09f0171
...
@@ -28,8 +28,8 @@ using S = ck::Sequence<Is...>;
...
@@ -28,8 +28,8 @@ using S = ck::Sequence<Is...>;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
static
constexpr
auto
MNPadding
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
;
static
constexpr
auto
MNKPadding
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
;
static
constexpr
auto
MNKPadding
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
;
template
<
ck
::
tensor_operation
::
device
::
GemmSpecialization
GemmSpec
>
template
<
ck
::
tensor_operation
::
device
::
GemmSpecialization
GemmSpec
>
// Compilation parameters for a[m, k] * b[n, k] = c[m, n]
// Compilation parameters for a[m, k] * b[n, k] = c[m, n]
...
@@ -98,6 +98,9 @@ void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(
...
@@ -98,6 +98,9 @@ void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(
add_device_operation_instances
(
add_device_operation_instances
(
instances
,
device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances
<
GemmDefault
>
{});
instances
,
device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances
<
GemmDefault
>
{});
add_device_operation_instances
(
instances
,
device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances
<
MNPadding
>
{});
add_device_operation_instances
(
add_device_operation_instances
(
instances
,
device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances
<
MNKPadding
>
{});
instances
,
device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances
<
MNKPadding
>
{});
}
}
...
...
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f8_f16_mk_kn_mn_instance.cpp
0 → 100644
View file @
a09f0171
This diff is collapsed.
Click to expand it.
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f8_f16_mk_nk_mn_instance.cpp
0 → 100644
View file @
a09f0171
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment