Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
e573a2a0
Unverified
Commit
e573a2a0
authored
Jun 30, 2022
by
Chao Liu
Committed by
GitHub
Jun 30, 2022
Browse files
Merge branch 'develop' into batched_gemm_g_stride_fix
parents
6adf3591
0dcb3496
Changes
258
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1001 additions
and
611 deletions
+1001
-611
library/include/ck/library/tensor_operation_instance/add_device_operation_instance.hpp
...nsor_operation_instance/add_device_operation_instance.hpp
+9
-2
library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp
..._operation_instance/device_operation_instance_factory.hpp
+33
-0
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm.hpp
...ck/library/tensor_operation_instance/gpu/batched_gemm.hpp
+259
-0
library/include/ck/library/tensor_operation_instance/gpu/device_batched_gemm_instance.hpp
...r_operation_instance/gpu/device_batched_gemm_instance.hpp
+0
-203
library/include/ck/library/tensor_operation_instance/gpu/device_elementwise_instance.hpp
...or_operation_instance/gpu/device_elementwise_instance.hpp
+4
-2
library/include/ck/library/tensor_operation_instance/gpu/device_gemm_add_add_fastgelu_instance.hpp
...on_instance/gpu/device_gemm_add_add_fastgelu_instance.hpp
+0
-93
library/include/ck/library/tensor_operation_instance/gpu/device_gemm_instance.hpp
...ry/tensor_operation_instance/gpu/device_gemm_instance.hpp
+0
-286
library/include/ck/library/tensor_operation_instance/gpu/device_gemm_mean_squaremean_instance.hpp
...ion_instance/gpu/device_gemm_mean_squaremean_instance.hpp
+7
-7
library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp
...include/ck/library/tensor_operation_instance/gpu/gemm.hpp
+383
-0
library/include/ck/library/tensor_operation_instance/gpu/gemm_add_add_fastgelu.hpp
...y/tensor_operation_instance/gpu/gemm_add_add_fastgelu.hpp
+141
-0
library/include/ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp
.../ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp
+147
-0
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp
..._instance/gpu/reduce/device_reduce_instance_blockwise.hpp
+2
-2
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16.hpp
...u/reduce/device_reduce_instance_blockwise_b16_f32_b16.hpp
+2
-2
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16.hpp
...u/reduce/device_reduce_instance_blockwise_f16_f16_f16.hpp
+2
-2
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16.hpp
...u/reduce/device_reduce_instance_blockwise_f16_f32_f16.hpp
+2
-2
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32.hpp
...u/reduce/device_reduce_instance_blockwise_f32_f32_f32.hpp
+2
-2
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32.hpp
...u/reduce/device_reduce_instance_blockwise_f32_f64_f32.hpp
+2
-2
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64.hpp
...u/reduce/device_reduce_instance_blockwise_f64_f64_f64.hpp
+2
-2
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8.hpp
...gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8.hpp
+2
-2
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8.hpp
.../gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8.hpp
+2
-2
No files found.
library/include/ck/library/tensor_operation_instance/device_operation_instance.hpp
→
library/include/ck/library/tensor_operation_instance/
add_
device_operation_instance.hpp
View file @
e573a2a0
...
@@ -4,14 +4,17 @@
...
@@ -4,14 +4,17 @@
#pragma once
#pragma once
#include <vector>
#include <vector>
#include <type_traits>
#include "ck/utility/functional2.hpp"
#include "ck/utility/functional2.hpp"
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
namespace
instance
{
template
<
typename
OpInstance
,
typename
NewOpInstances
>
template
<
typename
BaseOp
,
typename
NewOpInstances
>
void
add_device_operation_instances
(
std
::
vector
<
std
::
unique_ptr
<
OpInstance
>>&
op_instances
,
void
add_device_operation_instances
(
std
::
vector
<
std
::
unique_ptr
<
BaseOp
>>&
op_instances
,
const
NewOpInstances
&
new_op_instances
)
const
NewOpInstances
&
new_op_instances
)
{
{
ck
::
static_for
<
0
,
std
::
tuple_size_v
<
NewOpInstances
>
,
1
>
{}([
&
](
auto
i
)
{
ck
::
static_for
<
0
,
std
::
tuple_size_v
<
NewOpInstances
>
,
1
>
{}([
&
](
auto
i
)
{
...
@@ -19,10 +22,14 @@ void add_device_operation_instances(std::vector<std::unique_ptr<OpInstance>>& op
...
@@ -19,10 +22,14 @@ void add_device_operation_instances(std::vector<std::unique_ptr<OpInstance>>& op
using
NewOpInstance
=
remove_cvref_t
<
decltype
(
new_op_instance
)
>
;
using
NewOpInstance
=
remove_cvref_t
<
decltype
(
new_op_instance
)
>
;
static_assert
(
std
::
is_base_of_v
<
BaseOp
,
NewOpInstance
>
,
"wrong! NewOpInstance should be derived from BaseOp"
);
op_instances
.
push_back
(
std
::
make_unique
<
NewOpInstance
>
(
new_op_instance
));
op_instances
.
push_back
(
std
::
make_unique
<
NewOpInstance
>
(
new_op_instance
));
});
});
}
}
}
// namespace instance
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
}
// namespace ck
}
// namespace ck
library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp
0 → 100644
View file @
e573a2a0
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
// aliasing, for commonly used type
using
F64
=
double
;
using
F32
=
float
;
using
F16
=
ck
::
half_t
;
using
BF16
=
ck
::
bhalf_t
;
using
F16_F16
=
ck
::
Tuple
<
F16
,
F16
>
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
AddAddFastGelu
=
ck
::
tensor_operation
::
element_wise
::
AddAddFastGelu
;
template
<
typename
DeviceOp
>
struct
DeviceOperationInstanceFactory
;
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm.hpp
0 → 100644
View file @
e573a2a0
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemm
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemm
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemm
<
Row
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemm
<
Row
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemm
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemm
<
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemm
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemm
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemm
<
Col
,
Row
,
Row
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemm
<
Col
,
Col
,
Row
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemm
<
Row
,
Row
,
Row
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemm
<
Row
,
Col
,
Row
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemm
<
Col
,
Row
,
Row
,
int8_t
,
int8_t
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemm
<
Col
,
Col
,
Row
,
int8_t
,
int8_t
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemm
<
Row
,
Row
,
Row
,
int8_t
,
int8_t
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemm
<
Row
,
Col
,
Row
,
int8_t
,
int8_t
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
,
typename
ADataType
,
typename
BDataType
,
typename
CDataType
>
struct
DeviceOperationInstanceFactory
<
ck
::
tensor_operation
::
device
::
DeviceBatchedGemm
<
ALayout
,
BLayout
,
CLayout
,
ADataType
,
BDataType
,
CDataType
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
>>
{
using
DeviceOp
=
DeviceBatchedGemm
<
ALayout
,
BLayout
,
CLayout
,
ADataType
,
BDataType
,
CDataType
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
>
;
static
auto
GetInstances
()
{
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
if
constexpr
(
is_same_v
<
ADataType
,
float
>
&&
is_same_v
<
BDataType
,
float
>
&&
is_same_v
<
CDataType
,
float
>
)
{
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances
(
op_ptrs
);
}
}
else
if
constexpr
(
is_same_v
<
ADataType
,
half_t
>
&&
is_same_v
<
BDataType
,
half_t
>
&&
is_same_v
<
CDataType
,
half_t
>
)
{
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances
(
op_ptrs
);
}
}
else
if
constexpr
(
is_same_v
<
ADataType
,
bhalf_t
>
&&
is_same_v
<
BDataType
,
bhalf_t
>
&&
is_same_v
<
CDataType
,
bhalf_t
>
)
{
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances
(
op_ptrs
);
}
}
else
if
constexpr
(
is_same_v
<
ADataType
,
int8_t
>
&&
is_same_v
<
BDataType
,
int8_t
>
&&
is_same_v
<
CDataType
,
int8_t
>
)
{
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances
(
op_ptrs
);
}
}
return
op_ptrs
;
}
};
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/include/ck/library/tensor_operation_instance/gpu/device_batched_gemm_instance.hpp
deleted
100644 → 0
View file @
6adf3591
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_batched_gemm_instance
{
using
DeviceBatchedGemmNoOpPtr
=
ck
::
tensor_operation
::
device
::
DeviceBatchedGemmPtr
<
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
>
;
void
add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances
(
std
::
vector
<
DeviceBatchedGemmNoOpPtr
>&
);
void
add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances
(
std
::
vector
<
DeviceBatchedGemmNoOpPtr
>&
);
void
add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances
(
std
::
vector
<
DeviceBatchedGemmNoOpPtr
>&
);
void
add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances
(
std
::
vector
<
DeviceBatchedGemmNoOpPtr
>&
);
void
add_device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances
(
std
::
vector
<
DeviceBatchedGemmNoOpPtr
>&
);
void
add_device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances
(
std
::
vector
<
DeviceBatchedGemmNoOpPtr
>&
);
void
add_device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances
(
std
::
vector
<
DeviceBatchedGemmNoOpPtr
>&
);
void
add_device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances
(
std
::
vector
<
DeviceBatchedGemmNoOpPtr
>&
);
void
add_device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instances
(
std
::
vector
<
DeviceBatchedGemmNoOpPtr
>&
);
void
add_device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances
(
std
::
vector
<
DeviceBatchedGemmNoOpPtr
>&
);
void
add_device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances
(
std
::
vector
<
DeviceBatchedGemmNoOpPtr
>&
);
void
add_device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances
(
std
::
vector
<
DeviceBatchedGemmNoOpPtr
>&
);
void
add_device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances
(
std
::
vector
<
DeviceBatchedGemmNoOpPtr
>&
);
void
add_device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances
(
std
::
vector
<
DeviceBatchedGemmNoOpPtr
>&
);
void
add_device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances
(
std
::
vector
<
DeviceBatchedGemmNoOpPtr
>&
);
void
add_device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances
(
std
::
vector
<
DeviceBatchedGemmNoOpPtr
>&
);
template
<
typename
ADataType
,
typename
BDataType
,
typename
CDataType
,
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
auto
get_device_batched_gemm_instances
()
{
std
::
vector
<
DeviceBatchedGemmNoOpPtr
>
op_ptrs
;
if
constexpr
(
is_same
<
ADataType
,
float
>::
value
&&
is_same
<
BDataType
,
float
>::
value
&&
is_same
<
CDataType
,
float
>::
value
)
{
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_batched_gemm_instance
::
add_device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_batched_gemm_instance
::
add_device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_batched_gemm_instance
::
add_device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_batched_gemm_instance
::
add_device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances
(
op_ptrs
);
}
}
else
if
constexpr
(
is_same
<
ADataType
,
half_t
>::
value
&&
is_same
<
BDataType
,
half_t
>::
value
&&
is_same
<
CDataType
,
half_t
>::
value
)
{
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_batched_gemm_instance
::
add_device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_batched_gemm_instance
::
add_device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_batched_gemm_instance
::
add_device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_batched_gemm_instance
::
add_device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances
(
op_ptrs
);
}
}
else
if
constexpr
(
is_same
<
ADataType
,
bhalf_t
>::
value
&&
is_same
<
BDataType
,
bhalf_t
>::
value
&&
is_same
<
CDataType
,
bhalf_t
>::
value
)
{
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_batched_gemm_instance
::
add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_batched_gemm_instance
::
add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_batched_gemm_instance
::
add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_batched_gemm_instance
::
add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances
(
op_ptrs
);
}
}
else
if
constexpr
(
is_same
<
ADataType
,
int8_t
>::
value
&&
is_same
<
BDataType
,
int8_t
>::
value
&&
is_same
<
CDataType
,
int8_t
>::
value
)
{
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_batched_gemm_instance
::
add_device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_batched_gemm_instance
::
add_device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_batched_gemm_instance
::
add_device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_batched_gemm_instance
::
add_device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances
(
op_ptrs
);
}
}
return
op_ptrs
;
}
}
// namespace device_batched_gemm_instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/include/ck/library/tensor_operation_instance/gpu/device_elementwise_instance.hpp
View file @
e573a2a0
...
@@ -10,11 +10,12 @@
...
@@ -10,11 +10,12 @@
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/
add_
device_operation_instance.hpp"
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
namespace
instance
{
using
Normalize
=
ck
::
tensor_operation
::
element_wise
::
Normalize
;
using
Normalize
=
ck
::
tensor_operation
::
element_wise
::
Normalize
;
using
DeviceNormalizeFromMeanMeanSquarePtr
=
using
DeviceNormalizeFromMeanMeanSquarePtr
=
...
@@ -37,13 +38,14 @@ auto get_device_normalize_from_mean_meansquare_instances()
...
@@ -37,13 +38,14 @@ auto get_device_normalize_from_mean_meansquare_instances()
is_same
<
MeanSquareType
,
float
>::
value
&&
is_same
<
GammaDataType
,
half_t
>::
value
&&
is_same
<
MeanSquareType
,
float
>::
value
&&
is_same
<
GammaDataType
,
half_t
>::
value
&&
is_same
<
BetaDataType
,
half_t
>::
value
&&
is_same
<
OutputType
,
half_t
>::
value
)
is_same
<
BetaDataType
,
half_t
>::
value
&&
is_same
<
OutputType
,
half_t
>::
value
)
{
{
ck
::
tensor_operation
::
device
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_normalize_from_mean_squaremean_f16_f32_f32_f16_f16_instances
(
op_ptrs
);
add_device_normalize_from_mean_squaremean_f16_f32_f32_f16_f16_instances
(
op_ptrs
);
}
}
return
op_ptrs
;
return
op_ptrs
;
}
}
}
// namespace instance
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
}
// namespace ck
}
// namespace ck
library/include/ck/library/tensor_operation_instance/gpu/device_gemm_add_add_fastgelu_instance.hpp
deleted
100644 → 0
View file @
6adf3591
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_gemm_instance
{
using
DeviceGemmAddAddFastGeluPtr
=
ck
::
tensor_operation
::
device
::
DeviceGemmMultipleDPtr
<
2
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
AddAddFastGelu
>
;
void
add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances
(
std
::
vector
<
DeviceGemmAddAddFastGeluPtr
>&
);
void
add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances
(
std
::
vector
<
DeviceGemmAddAddFastGeluPtr
>&
);
void
add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances
(
std
::
vector
<
DeviceGemmAddAddFastGeluPtr
>&
);
void
add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances
(
std
::
vector
<
DeviceGemmAddAddFastGeluPtr
>&
);
template
<
typename
ADataType
,
typename
BDataType
,
typename
AccDataType
,
typename
D0DataType
,
typename
D1DataType
,
typename
EDataType
,
typename
ALayout
,
typename
BLayout
,
typename
D0Layout
,
typename
D1Layout
,
typename
ELayout
>
auto
get_device_gemm_add_add_fastgelu_instances
()
{
std
::
vector
<
DeviceGemmAddAddFastGeluPtr
>
op_ptrs
;
if
constexpr
(
is_same_v
<
ADataType
,
half_t
>
&&
is_same_v
<
BDataType
,
half_t
>
&&
is_same_v
<
EDataType
,
half_t
>
)
{
if
constexpr
(
is_same_v
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>
&&
is_same_v
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>
&&
is_same_v
<
ELayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>
&&
is_same_v
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>
&&
is_same_v
<
ELayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>
&&
is_same_v
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>
&&
is_same_v
<
ELayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>
&&
is_same_v
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>
&&
is_same_v
<
ELayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances
(
op_ptrs
);
}
}
return
op_ptrs
;
}
}
// namespace device_gemm_instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/include/ck/library/tensor_operation_instance/gpu/device_gemm_instance.hpp
deleted
100644 → 0
View file @
6adf3591
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_gemm_instance
{
using
DeviceGemmNoOpPtr
=
ck
::
tensor_operation
::
device
::
DeviceGemmPtr
<
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
>
;
void
add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
template
<
typename
ADataType
,
typename
BDataType
,
typename
CDataType
,
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
auto
get_device_gemm_instances
()
{
std
::
vector
<
DeviceGemmNoOpPtr
>
op_ptrs
;
if
constexpr
(
is_same
<
ADataType
,
float
>::
value
&&
is_same
<
BDataType
,
float
>::
value
&&
is_same
<
CDataType
,
float
>::
value
)
{
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances
(
op_ptrs
);
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances
(
op_ptrs
);
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances
(
op_ptrs
);
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances
(
op_ptrs
);
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances
(
op_ptrs
);
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances
(
op_ptrs
);
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances
(
op_ptrs
);
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances
(
op_ptrs
);
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances
(
op_ptrs
);
}
}
else
if
constexpr
(
is_same
<
ADataType
,
half_t
>::
value
&&
is_same
<
BDataType
,
half_t
>::
value
&&
is_same
<
CDataType
,
half_t
>::
value
)
{
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances
(
op_ptrs
);
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances
(
op_ptrs
);
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances
(
op_ptrs
);
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances
(
op_ptrs
);
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances
(
op_ptrs
);
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances
(
op_ptrs
);
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances
(
op_ptrs
);
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances
(
op_ptrs
);
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances
(
op_ptrs
);
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances
(
op_ptrs
);
}
}
else
if
constexpr
(
is_same
<
ADataType
,
ck
::
bhalf_t
>::
value
&&
is_same
<
BDataType
,
ck
::
bhalf_t
>::
value
&&
is_same
<
CDataType
,
ck
::
bhalf_t
>::
value
)
{
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances
(
op_ptrs
);
}
}
else
if
constexpr
(
is_same
<
ADataType
,
int8_t
>::
value
&&
is_same
<
BDataType
,
int8_t
>::
value
&&
is_same
<
CDataType
,
int8_t
>::
value
)
{
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances
(
op_ptrs
);
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances
(
op_ptrs
);
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances
(
op_ptrs
);
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances
(
op_ptrs
);
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances
(
op_ptrs
);
}
}
return
op_ptrs
;
}
}
// namespace device_gemm_instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/include/ck/library/tensor_operation_instance/gpu/device_gemm_mean_squaremean_instance.hpp
View file @
e573a2a0
...
@@ -10,12 +10,12 @@
...
@@ -10,12 +10,12 @@
#include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/
add_
device_operation_instance.hpp"
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
namespace
device_gemm_
instance
{
namespace
instance
{
using
DeviceGemmAddAddMeanSquareMeanPtr
=
ck
::
tensor_operation
::
device
::
DeviceGemmReducePtr
<
1
,
2
>
;
using
DeviceGemmAddAddMeanSquareMeanPtr
=
ck
::
tensor_operation
::
device
::
DeviceGemmReducePtr
<
1
,
2
>
;
...
@@ -45,7 +45,7 @@ auto get_device_gemm_add_add_mean_squaremean_instances()
...
@@ -45,7 +45,7 @@ auto get_device_gemm_add_add_mean_squaremean_instances()
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
{
ck
::
tensor_operation
::
device
::
device_gemm_
instance
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances
(
add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances
(
op_ptrs
);
op_ptrs
);
}
}
...
@@ -53,7 +53,7 @@ auto get_device_gemm_add_add_mean_squaremean_instances()
...
@@ -53,7 +53,7 @@ auto get_device_gemm_add_add_mean_squaremean_instances()
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
{
ck
::
tensor_operation
::
device
::
device_gemm_
instance
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances
(
add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances
(
op_ptrs
);
op_ptrs
);
}
}
...
@@ -61,7 +61,7 @@ auto get_device_gemm_add_add_mean_squaremean_instances()
...
@@ -61,7 +61,7 @@ auto get_device_gemm_add_add_mean_squaremean_instances()
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
{
ck
::
tensor_operation
::
device
::
device_gemm_
instance
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances
(
add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances
(
op_ptrs
);
op_ptrs
);
}
}
...
@@ -69,7 +69,7 @@ auto get_device_gemm_add_add_mean_squaremean_instances()
...
@@ -69,7 +69,7 @@ auto get_device_gemm_add_add_mean_squaremean_instances()
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
{
ck
::
tensor_operation
::
device
::
device_gemm_
instance
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances
(
add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances
(
op_ptrs
);
op_ptrs
);
}
}
...
@@ -78,7 +78,7 @@ auto get_device_gemm_add_add_mean_squaremean_instances()
...
@@ -78,7 +78,7 @@ auto get_device_gemm_add_add_mean_squaremean_instances()
return
op_ptrs
;
return
op_ptrs
;
}
}
}
// namespace
device_gemm_
instance
}
// namespace instance
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
}
// namespace ck
}
// namespace ck
library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp
0 → 100644
View file @
e573a2a0
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Col
,
Row
,
Row
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Col
,
Col
,
Row
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Row
,
Row
,
Row
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Row
,
Col
,
Row
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Col
,
Row
,
Row
,
int8_t
,
int8_t
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Col
,
Col
,
Row
,
int8_t
,
int8_t
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Row
,
Row
,
Row
,
int8_t
,
int8_t
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Row
,
Col
,
Row
,
int8_t
,
int8_t
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Row
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Row
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Col
,
Row
,
Row
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Col
,
Col
,
Row
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Row
,
Row
,
Row
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Row
,
Col
,
Row
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Col
,
Row
,
Row
,
int8_t
,
int8_t
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Col
,
Col
,
Row
,
int8_t
,
int8_t
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Row
,
Row
,
Row
,
int8_t
,
int8_t
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Row
,
Col
,
Row
,
int8_t
,
int8_t
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Col
,
Row
,
Row
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Col
,
Col
,
Row
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Row
,
Row
,
Row
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Row
,
Col
,
Row
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_f64_f64_f64_km_kn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Col
,
Row
,
Row
,
F64
,
F64
,
F64
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_f64_f64_f64_km_nk_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Col
,
Col
,
Row
,
F64
,
F64
,
F64
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_f64_f64_f64_mk_kn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Row
,
Row
,
Row
,
F64
,
F64
,
F64
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_f64_f64_f64_mk_nk_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemm
<
Row
,
Col
,
Row
,
F64
,
F64
,
F64
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
,
typename
ADataType
,
typename
BDataType
,
typename
CDataType
>
struct
DeviceOperationInstanceFactory
<
ck
::
tensor_operation
::
device
::
DeviceGemm
<
ALayout
,
BLayout
,
CLayout
,
ADataType
,
BDataType
,
CDataType
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
>>
{
using
DeviceOp
=
DeviceGemm
<
ALayout
,
BLayout
,
CLayout
,
ADataType
,
BDataType
,
CDataType
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
>
;
static
auto
GetInstances
()
{
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
if
constexpr
(
is_same_v
<
ADataType
,
float
>
&&
is_same_v
<
BDataType
,
float
>
&&
is_same_v
<
CDataType
,
float
>
)
{
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances
(
op_ptrs
);
add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances
(
op_ptrs
);
add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances
(
op_ptrs
);
add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances
(
op_ptrs
);
add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances
(
op_ptrs
);
add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances
(
op_ptrs
);
add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances
(
op_ptrs
);
add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances
(
op_ptrs
);
add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances
(
op_ptrs
);
}
}
else
if
constexpr
(
is_same_v
<
ADataType
,
half_t
>
&&
is_same_v
<
BDataType
,
half_t
>
&&
is_same_v
<
CDataType
,
half_t
>
)
{
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances
(
op_ptrs
);
add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances
(
op_ptrs
);
add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances
(
op_ptrs
);
add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances
(
op_ptrs
);
add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances
(
op_ptrs
);
add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances
(
op_ptrs
);
add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances
(
op_ptrs
);
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances
(
op_ptrs
);
add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances
(
op_ptrs
);
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances
(
op_ptrs
);
}
}
else
if
constexpr
(
is_same_v
<
ADataType
,
ck
::
bhalf_t
>
&&
is_same_v
<
BDataType
,
ck
::
bhalf_t
>
&&
is_same_v
<
CDataType
,
ck
::
bhalf_t
>
)
{
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances
(
op_ptrs
);
}
}
else
if
constexpr
(
is_same_v
<
ADataType
,
int8_t
>
&&
is_same_v
<
BDataType
,
int8_t
>
&&
is_same_v
<
CDataType
,
int8_t
>
)
{
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances
(
op_ptrs
);
add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances
(
op_ptrs
);
add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances
(
op_ptrs
);
add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances
(
op_ptrs
);
add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances
(
op_ptrs
);
}
}
return
op_ptrs
;
}
};
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/include/ck/library/tensor_operation_instance/gpu/gemm_add_add_fastgelu.hpp
0 → 100644
View file @
e573a2a0
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16_F16
,
F16
,
PassThrough
,
PassThrough
,
AddAddFastGelu
>>>&
);
void
add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16_F16
,
F16
,
PassThrough
,
PassThrough
,
AddAddFastGelu
>>>&
);
void
add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16_F16
,
F16
,
PassThrough
,
PassThrough
,
AddAddFastGelu
>>>&
);
void
add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Col
,
Col
,
Row
,
F16
,
F16
,
F16_F16
,
F16
,
PassThrough
,
PassThrough
,
AddAddFastGelu
>>>&
);
// GEMM + Add + Add + FastGelu
template
<
typename
ALayout
,
typename
BLayout
,
typename
DELayout
,
typename
ADataType
,
typename
BDataType
,
typename
DsDataType
,
typename
EDataType
>
struct
DeviceOperationInstanceFactory
<
ck
::
tensor_operation
::
device
::
DeviceGemmMultipleD
<
ALayout
,
BLayout
,
DELayout
,
ADataType
,
BDataType
,
DsDataType
,
EDataType
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
AddAddFastGelu
>>
{
using
DeviceOp
=
DeviceGemmMultipleD
<
ALayout
,
BLayout
,
DELayout
,
ADataType
,
BDataType
,
DsDataType
,
EDataType
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
AddAddFastGelu
>
;
static
auto
GetInstances
()
{
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
if
constexpr
(
is_same_v
<
ADataType
,
half_t
>
&&
is_same_v
<
BDataType
,
half_t
>
&&
is_same_v
<
DsDataType
,
Tuple
<
half_t
,
half_t
>>
&&
is_same_v
<
EDataType
,
half_t
>
)
{
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
DELayout
,
Row
>
)
{
add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
DELayout
,
Row
>
)
{
add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
DELayout
,
Row
>
)
{
add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
DELayout
,
Row
>
)
{
add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances
(
op_ptrs
);
}
}
return
op_ptrs
;
}
};
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/include/ck/library/tensor_operation_instance/gpu/
device_
gemm_splitk
_instance
.hpp
→
library/include/ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp
View file @
e573a2a0
...
@@ -10,35 +10,52 @@
...
@@ -10,35 +10,52 @@
#include "ck/tensor_operation/gpu/device/device_gemm_splitk.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_splitk.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance
_factory
.hpp"
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
namespace
device_gemm_
instance
{
namespace
instance
{
using
D
evice
G
emm
SplitKNoOpPtr
=
ck
::
tensor_operation
::
device
::
DeviceGemmSplitKPtr
<
void
add_d
evice
_g
emm
_xdl_splitk_f16_f16_f16_km_kn_mn_instances
(
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
std
::
vector
<
std
::
unique_ptr
<
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
DeviceGemmSplitK
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
ck
::
tensor_operation
::
element_wise
::
PassThrough
>
;
instances
)
;
void
add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances
(
void
add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances
(
std
::
vector
<
DeviceGemmSplitKNoOpPtr
>&
);
std
::
vector
<
std
::
unique_ptr
<
void
add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances
(
DeviceGemmSplitK
<
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
std
::
vector
<
DeviceGemmSplitKNoOpPtr
>&
);
instances
);
void
add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances
(
std
::
vector
<
DeviceGemmSplitKNoOpPtr
>&
);
void
add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances
(
std
::
vector
<
DeviceGemmSplitKNoOpPtr
>&
);
void
add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances
(
void
add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances
(
std
::
vector
<
DeviceGemmSplitKNoOpPtr
>&
);
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmSplitK
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances
(
void
add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances
(
std
::
vector
<
DeviceGemmSplitKNoOpPtr
>&
);
std
::
vector
<
std
::
unique_ptr
<
void
add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances
(
DeviceGemmSplitK
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
std
::
vector
<
DeviceGemmSplitKNoOpPtr
>&
);
instances
);
void
add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances
(
std
::
vector
<
DeviceGemmSplitKNoOpPtr
>&
);
void
add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmSplitK
<
Col
,
Row
,
Row
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmSplitK
<
Col
,
Col
,
Row
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmSplitK
<
Row
,
Row
,
Row
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmSplitK
<
Row
,
Col
,
Row
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
template
<
typename
ADataType
,
template
<
typename
ADataType
,
typename
BDataType
,
typename
BDataType
,
...
@@ -46,79 +63,85 @@ template <typename ADataType,
...
@@ -46,79 +63,85 @@ template <typename ADataType,
typename
ALayout
,
typename
ALayout
,
typename
BLayout
,
typename
BLayout
,
typename
CLayout
>
typename
CLayout
>
auto
get_device_gemm_splitk_instances
()
struct
DeviceOperationInstanceFactory
<
ck
::
tensor_operation
::
device
::
DeviceGemmSplitK
<
ALayout
,
BLayout
,
CLayout
,
ADataType
,
BDataType
,
CDataType
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
>>
{
{
std
::
vector
<
DeviceGemmSplitKNoOpPtr
>
op_ptrs
;
using
DeviceOp
=
DeviceGemmSplitK
<
ALayout
,
BLayout
,
CLayout
,
ADataType
,
BDataType
,
CDataType
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
>
;
if
constexpr
(
is_same
<
ADataType
,
float
>::
value
&&
is_same
<
BDataType
,
float
>::
value
&&
static
auto
GetInstances
()
is_same
<
CDataType
,
float
>::
value
)
{
{
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
if
constexpr
(
is_same_v
<
ADataType
,
float
>
&&
is_same_v
<
BDataType
,
float
>
&&
is_same_v
<
CDataType
,
float
>
)
{
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances
(
op_ptrs
);
add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances
(
op_ptrs
);
}
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
else
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same_v
<
CLayout
,
Row
>
)
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances
(
op_ptrs
);
add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances
(
op_ptrs
);
}
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same_v
<
CLayout
,
Row
>
)
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances
(
op_ptrs
);
add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances
(
op_ptrs
);
}
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same_v
<
CLayout
,
Row
>
)
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances
(
op_ptrs
);
add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances
(
op_ptrs
);
}
}
}
}
else
if
constexpr
(
is_same_v
<
ADataType
,
half_t
>
&&
is_same_v
<
BDataType
,
half_t
>
&&
else
if
constexpr
(
is_same
<
ADataType
,
half_t
>::
value
&&
is_same
<
BDataType
,
half_t
>::
value
&&
is_same_v
<
CDataType
,
half_t
>
)
is_same
<
CDataType
,
half_t
>::
value
)
{
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances
(
op_ptrs
);
add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances
(
op_ptrs
);
}
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
else
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same_v
<
CLayout
,
Row
>
)
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances
(
op_ptrs
);
add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances
(
op_ptrs
);
}
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same_v
<
CLayout
,
Row
>
)
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances
(
op_ptrs
);
add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances
(
op_ptrs
);
}
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same_v
<
CLayout
,
Row
>
)
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances
(
op_ptrs
);
add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances
(
op_ptrs
);
}
}
}
}
return
op_ptrs
;
return
op_ptrs
;
}
}
};
}
// namespace
device_gemm_
instance
}
// namespace instance
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
}
// namespace ck
}
// namespace ck
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp
View file @
e573a2a0
...
@@ -10,7 +10,7 @@
...
@@ -10,7 +10,7 @@
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
namespace
device_reduce_
instance
{
namespace
instance
{
using
reduce_configuration_1_instances_blockwise
=
std
::
tuple
<
using
reduce_configuration_1_instances_blockwise
=
std
::
tuple
<
// clang-format off
// clang-format off
...
@@ -174,7 +174,7 @@ void add_device_reduce_instance_blockwise(
...
@@ -174,7 +174,7 @@ void add_device_reduce_instance_blockwise(
Rank, \
Rank, \
NumReduceDim)
NumReduceDim)
}
// namespace
device_reduce_
instance
}
// namespace instance
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
}
// namespace ck
}
// namespace ck
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16.hpp
View file @
e573a2a0
...
@@ -10,7 +10,7 @@
...
@@ -10,7 +10,7 @@
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
namespace
device_reduce_
instance
{
namespace
instance
{
// clang-format off
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
...
@@ -53,7 +53,7 @@ ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 1);
...
@@ -53,7 +53,7 @@ ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID
(
bhalf_t
,
float
,
bhalf_t
,
4
,
0
,
1
,
2
,
1
);
ADD_BLOCKWISE_INST_REF_BY_ID
(
bhalf_t
,
float
,
bhalf_t
,
4
,
0
,
1
,
2
,
1
);
// clang-format on
// clang-format on
}
// namespace
device_reduce_
instance
}
// namespace instance
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
}
// namespace ck
}
// namespace ck
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16.hpp
View file @
e573a2a0
...
@@ -10,7 +10,7 @@
...
@@ -10,7 +10,7 @@
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
namespace
device_reduce_
instance
{
namespace
instance
{
// clang-format off
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
...
@@ -40,7 +40,7 @@ ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 1);
...
@@ -40,7 +40,7 @@ ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID
(
half_t
,
half_t
,
half_t
,
4
,
0
,
1
,
2
,
1
);
ADD_BLOCKWISE_INST_REF_BY_ID
(
half_t
,
half_t
,
half_t
,
4
,
0
,
1
,
2
,
1
);
// clang-format on
// clang-format on
}
// namespace
device_reduce_
instance
}
// namespace instance
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
}
// namespace ck
}
// namespace ck
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16.hpp
View file @
e573a2a0
...
@@ -10,7 +10,7 @@
...
@@ -10,7 +10,7 @@
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
namespace
device_reduce_
instance
{
namespace
instance
{
// clang-format off
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
...
@@ -28,7 +28,7 @@ ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 1);
...
@@ -28,7 +28,7 @@ ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID
(
half_t
,
float
,
half_t
,
7
,
0
,
0
,
2
,
1
);
ADD_BLOCKWISE_INST_REF_BY_ID
(
half_t
,
float
,
half_t
,
7
,
0
,
0
,
2
,
1
);
// clang-format on
// clang-format on
}
// namespace
device_reduce_
instance
}
// namespace instance
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
}
// namespace ck
}
// namespace ck
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32.hpp
View file @
e573a2a0
...
@@ -10,7 +10,7 @@
...
@@ -10,7 +10,7 @@
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
namespace
device_reduce_
instance
{
namespace
instance
{
// clang-format off
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
...
@@ -52,7 +52,7 @@ ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 1);
...
@@ -52,7 +52,7 @@ ADD_BLOCKWISE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID
(
float
,
float
,
float
,
4
,
0
,
1
,
2
,
1
);
ADD_BLOCKWISE_INST_REF_BY_ID
(
float
,
float
,
float
,
4
,
0
,
1
,
2
,
1
);
// clang-format on
// clang-format on
}
// namespace
device_reduce_
instance
}
// namespace instance
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
}
// namespace ck
}
// namespace ck
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32.hpp
View file @
e573a2a0
...
@@ -10,7 +10,7 @@
...
@@ -10,7 +10,7 @@
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
namespace
device_reduce_
instance
{
namespace
instance
{
// clang-format off
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
...
@@ -28,7 +28,7 @@ ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 1);
...
@@ -28,7 +28,7 @@ ADD_BLOCKWISE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID
(
float
,
double
,
float
,
7
,
0
,
0
,
2
,
1
);
ADD_BLOCKWISE_INST_REF_BY_ID
(
float
,
double
,
float
,
7
,
0
,
0
,
2
,
1
);
// clang-format on
// clang-format on
}
// namespace
device_reduce_
instance
}
// namespace instance
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
}
// namespace ck
}
// namespace ck
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64.hpp
View file @
e573a2a0
...
@@ -10,7 +10,7 @@
...
@@ -10,7 +10,7 @@
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
namespace
device_reduce_
instance
{
namespace
instance
{
// clang-format off
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
...
@@ -52,7 +52,7 @@ ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 1);
...
@@ -52,7 +52,7 @@ ADD_BLOCKWISE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID
(
double
,
double
,
double
,
4
,
0
,
1
,
2
,
1
);
ADD_BLOCKWISE_INST_REF_BY_ID
(
double
,
double
,
double
,
4
,
0
,
1
,
2
,
1
);
// clang-format on
// clang-format on
}
// namespace
device_reduce_
instance
}
// namespace instance
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
}
// namespace ck
}
// namespace ck
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8.hpp
View file @
e573a2a0
...
@@ -10,7 +10,7 @@
...
@@ -10,7 +10,7 @@
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
namespace
device_reduce_
instance
{
namespace
instance
{
// clang-format off
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
...
@@ -24,7 +24,7 @@ ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 1);
...
@@ -24,7 +24,7 @@ ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID
(
int8_t
,
int32_t
,
int8_t
,
5
,
0
,
0
,
2
,
1
);
ADD_BLOCKWISE_INST_REF_BY_ID
(
int8_t
,
int32_t
,
int8_t
,
5
,
0
,
0
,
2
,
1
);
// clang-format on
// clang-format on
}
// namespace
device_reduce_
instance
}
// namespace instance
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
}
// namespace ck
}
// namespace ck
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8.hpp
View file @
e573a2a0
...
@@ -10,7 +10,7 @@
...
@@ -10,7 +10,7 @@
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
namespace
device_reduce_
instance
{
namespace
instance
{
// clang-format off
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
...
@@ -40,7 +40,7 @@ ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 1);
...
@@ -40,7 +40,7 @@ ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID
(
int8_t
,
int8_t
,
int8_t
,
4
,
0
,
1
,
2
,
1
);
ADD_BLOCKWISE_INST_REF_BY_ID
(
int8_t
,
int8_t
,
int8_t
,
4
,
0
,
1
,
2
,
1
);
// clang-format on
// clang-format on
}
// namespace
device_reduce_
instance
}
// namespace instance
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
}
// namespace ck
}
// namespace ck
Prev
1
2
3
4
5
6
…
13
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment