gaoqiong / composable_kernel · Commit 4100d1d8

Commit 4100d1d8, authored Aug 23, 2023 by Alan Turner

    Merge remote-tracking branch 'origin/develop' into migx-flash-attn

Parents: 48717006, c8a8385f

Changes: 609 files in total; showing 20 changed files with 184 additions and 243 deletions (+184 −243).
library/include/ck/library/tensor_operation_instance/gpu/normalization.hpp  +10 −7
library/include/ck/library/tensor_operation_instance/gpu/pool2d_fwd.hpp  +0 −111
library/include/ck/library/tensor_operation_instance/gpu/pool3d_fwd.hpp  +47 −34
library/include/ck/library/tensor_operation_instance/gpu/quantization/gemm_quantization.hpp  +12 −3
library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_bias_forward_perchannel_quantization.hpp  +10 −3
library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_bias_forward_perlayer_quantization.hpp  +10 −3
library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_forward_perchannel_quantization.hpp  +8 −3
library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_forward_perlayer_quantization.hpp  +8 −3
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp  +4 −4
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp  +4 −4
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp  +2 −2
library/include/ck/library/tensor_operation_instance/gpu/softmax.hpp  +62 −37
library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance.hpp  +0 −22
library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce1.hpp  +1 −1
library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce2.hpp  +1 −1
library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce3.hpp  +1 −1
library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce1.hpp  +1 −1
library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce2.hpp  +1 −1
library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce3.hpp  +1 −1
library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce4.hpp  +1 −1
library/include/ck/library/tensor_operation_instance/gpu/normalization.hpp

@@ -16,7 +16,7 @@ namespace ck {
 namespace tensor_operation {
 namespace device {
 namespace instance {
 #ifdef CK_ENABLE_FP16
 // FP16
 void add_device_normalization_rank_2_1_f16_instances(
     std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F32, F16, PassThrough, 2, 1>>>&);
...
@@ -26,7 +26,8 @@ void add_device_normalization_rank_4_3_f16_instances(
 void add_device_normalization_rank_5_3_f16_instances(
     std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F32, F16, PassThrough, 5, 3>>>&);
 #endif
 #ifdef CK_ENABLE_FP32
 // FP32
 void add_device_normalization_rank_2_1_f32_instances(
     std::vector<std::unique_ptr<DeviceNormalization<F32, F32, F32, F32, F32, PassThrough, 2, 1>>>&);
...
@@ -36,7 +37,7 @@ void add_device_normalization_rank_4_3_f32_instances(
 void add_device_normalization_rank_5_3_f32_instances(
     std::vector<std::unique_ptr<DeviceNormalization<F32, F32, F32, F32, F32, PassThrough, 5, 3>>>&);
 #endif

 template <typename XDataType,
           typename GammaDataType,
           typename BetaDataType,
...
@@ -65,7 +66,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceNormal
     static auto GetInstances()
     {
         std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
 #ifdef CK_ENABLE_FP16
         if constexpr(is_same_v<XDataType, F16> && is_same_v<GammaDataType, F16> &&
                      is_same_v<BetaDataType, F16> && is_same_v<YDataType, F16>)
         {
...
@@ -82,8 +83,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceNormal
                 add_device_normalization_rank_5_3_f16_instances(op_ptrs);
             }
         }
-        else if constexpr(is_same_v<XDataType, F32> && is_same_v<GammaDataType, F32> &&
-                          is_same_v<BetaDataType, F32> && is_same_v<YDataType, F32>)
+#endif
+#ifdef CK_ENABLE_FP32
+        if constexpr(is_same_v<XDataType, F32> && is_same_v<GammaDataType, F32> &&
+                     is_same_v<BetaDataType, F32> && is_same_v<YDataType, F32>)
         {
             if constexpr(Rank == 2 && NumReduceDim == 1)
             {
...
@@ -98,7 +101,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceNormal
                 add_device_normalization_rank_5_3_f32_instances(op_ptrs);
             }
         }
 #endif
         return op_ptrs;
     }
 };
...
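For context, this factory is what client code queries to enumerate the prebuilt normalization kernels. A minimal usage sketch (not part of the commit), assuming an FP16 build of the instance library and that the F16/F32 aliases above are ck::half_t/float as elsewhere in CK:

// Sketch: enumerate rank-2, reduce-dim-1 F16 normalization instances.
// The template arguments mirror the declaration above:
// X/Gamma/Beta data types, compute type, Y type, elementwise op, Rank, NumReduceDim.
#include <cstddef>
#include "ck/library/tensor_operation_instance/gpu/normalization.hpp"

using ck::tensor_operation::device::DeviceNormalization;
using PassThroughOp = ck::tensor_operation::element_wise::PassThrough;

using DeviceOp =
    DeviceNormalization<ck::half_t, ck::half_t, ck::half_t, float, ck::half_t, PassThroughOp, 2, 1>;

std::size_t count_normalization_instances()
{
    auto op_ptrs = ck::tensor_operation::device::instance::
        DeviceOperationInstanceFactory<DeviceOp>::GetInstances();
    // A real caller would build an Argument for its problem size and keep only
    // the instances whose IsSupportedArgument(...) returns true.
    return op_ptrs.size();
}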
library/include/ck/library/tensor_operation_instance/gpu/pool2d_fwd.hpp
deleted 100644 → 0

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_pool_fwd.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

static constexpr auto InOutRank  = 4;
static constexpr auto WindowRank = 2;
static constexpr auto MaxOp = ck::ReduceTensorOp::MAX;
static constexpr auto AvgOp = ck::ReduceTensorOp::AVG;

// FP16
void add_device_pool2d_fwd_nhwc_f16_instances(
    std::vector<std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, MaxOp, false>>>&);
void add_device_pool2d_fwd_nhwc_f16_instances(
    std::vector<std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, AvgOp, false>>>&);

// FP16 - return index
void add_device_pool2d_fwd_nhwc_index_f16_instances(
    std::vector<std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, MaxOp, true>>>&);

// FP32
void add_device_pool2d_fwd_nhwc_f32_instances(
    std::vector<std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, MaxOp, false>>>&);
void add_device_pool2d_fwd_nhwc_f32_instances(
    std::vector<std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, AvgOp, false>>>&);

// FP32 - return index
void add_device_pool2d_fwd_nhwc_index_f32_instances(
    std::vector<std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, MaxOp, true>>>&);

template <typename InDataType,
          typename OutDataType,
          typename IndexDataType,
          ck::ReduceTensorOp ReduceOpId,
          bool OutputIndex>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DevicePoolFwd<InOutRank,
                                                                                  WindowRank,
                                                                                  InDataType,
                                                                                  OutDataType,
                                                                                  IndexDataType,
                                                                                  ReduceOpId,
                                                                                  OutputIndex>>
{
    using DeviceOp = DevicePoolFwd<InOutRank,
                                   WindowRank,
                                   InDataType,
                                   OutDataType,
                                   IndexDataType,
                                   ReduceOpId,
                                   OutputIndex>;

    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

        if constexpr(is_same_v<InDataType, F16> && is_same_v<OutDataType, F16> &&
                     is_same_v<IndexDataType, I32>)
        {
            if constexpr(OutputIndex && ReduceOpId == MaxOp)
            {
                add_device_pool2d_fwd_nhwc_index_f16_instances(op_ptrs);
            }
            else
            {
                add_device_pool2d_fwd_nhwc_f16_instances(op_ptrs);
            }
        }
        else if constexpr(is_same_v<InDataType, F32> && is_same_v<OutDataType, F32> &&
                          is_same_v<IndexDataType, I32>)
        {
            if constexpr(OutputIndex && ReduceOpId == MaxOp)
            {
                add_device_pool2d_fwd_nhwc_index_f32_instances(op_ptrs);
            }
            else
            {
                add_device_pool2d_fwd_nhwc_f32_instances(op_ptrs);
            }
        }

        return op_ptrs;
    }
};

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
library/include/ck/library/tensor_operation_instance/gpu/pool3d_fwd.hpp

@@ -22,38 +22,41 @@ static constexpr auto WindowRank = 3;
 static constexpr auto MaxOp = ck::ReduceTensorOp::MAX;
 static constexpr auto AvgOp = ck::ReduceTensorOp::AVG;

 #ifdef CK_ENABLE_FP16
 // FP16
 void add_device_pool3d_fwd_ndhwc_f16_instances(
-    std::vector<std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, MaxOp, false>>>&);
+    std::vector<std::unique_ptr<
+        DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, NDHWC, NDHWC, MaxOp, false>>>&);
 void add_device_pool3d_fwd_ndhwc_f16_instances(
-    std::vector<std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, AvgOp, false>>>&);
+    std::vector<std::unique_ptr<
+        DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, NDHWC, NDHWC, AvgOp, false>>>&);
 // FP16 - return index
 void add_device_pool3d_fwd_ndhwc_index_f16_instances(
-    std::vector<std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, MaxOp, true>>>&);
+    std::vector<std::unique_ptr<
+        DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, NDHWC, NDHWC, MaxOp, true>>>&);
 #endif
 #ifdef CK_ENABLE_FP32
 // FP32
 void add_device_pool3d_fwd_ndhwc_f32_instances(
-    std::vector<std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, MaxOp, false>>>&);
+    std::vector<std::unique_ptr<
+        DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, NDHWC, NDHWC, MaxOp, false>>>&);
 void add_device_pool3d_fwd_ndhwc_f32_instances(
-    std::vector<std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, AvgOp, false>>>&);
+    std::vector<std::unique_ptr<
+        DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, NDHWC, NDHWC, AvgOp, false>>>&);
 // FP32 - return index
 void add_device_pool3d_fwd_ndhwc_index_f32_instances(
-    std::vector<std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, MaxOp, true>>>&);
+    std::vector<std::unique_ptr<
+        DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, NDHWC, NDHWC, MaxOp, true>>>&);
 #endif

 template <typename InDataType,
           typename OutDataType,
           typename IndexDataType,
+          typename InLayout,
+          typename OutLayout,
           ck::ReduceTensorOp ReduceOpId,
           bool OutputIndex>
 struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DevicePoolFwd<InOutRank,
...
@@ -61,6 +64,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DevicePoolFw
                                                                                   InDataType,
                                                                                   OutDataType,
                                                                                   IndexDataType,
+                                                                                  InLayout,
+                                                                                  OutLayout,
                                                                                   ReduceOpId,
                                                                                   OutputIndex>>
 {
...
@@ -69,36 +74,44 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DevicePoolFw
     using DeviceOp = DevicePoolFwd<InOutRank,
                                    WindowRank,
                                    InDataType,
                                    OutDataType,
                                    IndexDataType,
+                                   InLayout,
+                                   OutLayout,
                                    ReduceOpId,
                                    OutputIndex>;

     static auto GetInstances()
     {
         std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

-        if constexpr(is_same_v<InDataType, F16> && is_same_v<OutDataType, F16> &&
-                     is_same_v<IndexDataType, I32>)
-        {
-            if constexpr(OutputIndex && ReduceOpId == MaxOp)
-            {
-                add_device_pool3d_fwd_ndhwc_index_f16_instances(op_ptrs);
-            }
-            else
-            {
-                add_device_pool3d_fwd_ndhwc_f16_instances(op_ptrs);
-            }
-        }
-        else if constexpr(is_same_v<InDataType, F32> && is_same_v<OutDataType, F32> &&
-                          is_same_v<IndexDataType, I32>)
+        if constexpr(is_same_v<InLayout, NDHWC> && is_same_v<OutLayout, NDHWC>)
         {
-            if constexpr(OutputIndex && ReduceOpId == MaxOp)
+#ifdef CK_ENABLE_FP16
+            if constexpr(is_same_v<InDataType, F16> && is_same_v<OutDataType, F16> &&
+                         is_same_v<IndexDataType, I32>)
             {
-                add_device_pool3d_fwd_ndhwc_index_f32_instances(op_ptrs);
+                if constexpr(OutputIndex && ReduceOpId == MaxOp)
+                {
+                    add_device_pool3d_fwd_ndhwc_index_f16_instances(op_ptrs);
+                }
+                else
+                {
+                    add_device_pool3d_fwd_ndhwc_f16_instances(op_ptrs);
+                }
             }
             else
+#endif
+#ifdef CK_ENABLE_FP32
+            if constexpr(is_same_v<InDataType, F32> && is_same_v<OutDataType, F32> &&
+                         is_same_v<IndexDataType, I32>)
             {
-                add_device_pool3d_fwd_ndhwc_f32_instances(op_ptrs);
+                if constexpr(OutputIndex && ReduceOpId == MaxOp)
+                {
+                    add_device_pool3d_fwd_ndhwc_index_f32_instances(op_ptrs);
+                }
+                else
+                {
+                    add_device_pool3d_fwd_ndhwc_f32_instances(op_ptrs);
+                }
             }
+#endif
         }
         return op_ptrs;
...
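The pool3d factory is now keyed on the tensor layouts as well as the data types. A usage sketch (not part of the commit), assuming the header's InOutRank constant is 5 (its definition is elided above; the deleted pool2d header used 4 for NHWC, so 5 for NDHWC is an assumption) and that NDHWC is ck::tensor_layout::convolution::NDHWC:

// Sketch: enumerate NDHWC max-pool3d instances that also return argmax indices.
#include <cstddef>
#include <cstdint>
#include "ck/library/tensor_operation_instance/gpu/pool3d_fwd.hpp"

using NDHWCLayout = ck::tensor_layout::convolution::NDHWC;

using MaxPool3dOp = ck::tensor_operation::device::DevicePoolFwd<
    5,                       // InOutRank (assumed value of the header constant)
    3,                       // WindowRank (D, H, W)
    ck::half_t,              // InDataType
    ck::half_t,              // OutDataType
    int32_t,                 // IndexDataType
    NDHWCLayout,             // InLayout  (new template parameter)
    NDHWCLayout,             // OutLayout (new template parameter)
    ck::ReduceTensorOp::MAX, // ReduceOpId
    true>;                   // OutputIndex

std::size_t count_pool3d_instances()
{
    auto op_ptrs = ck::tensor_operation::device::instance::
        DeviceOperationInstanceFactory<MaxPool3dOp>::GetInstances();
    return op_ptrs.size(); // non-empty only for NDHWC/NDHWC, per the guard above
}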
library/include/ck/library/tensor_operation_instance/gpu/quantization/gemm_quantization.hpp

@@ -11,12 +11,12 @@
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"

 #ifdef CK_ENABLE_INT8
 namespace ck {
 namespace tensor_operation {
 namespace device {
 namespace instance {
 #ifdef DL_KERNELS
 // Layout(A, B, C) = [Col, Row, Row]
 void add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_kn_mn_instances(
     std::vector<std::unique_ptr<DeviceGemmMultipleD<Col,
...
@@ -76,7 +76,7 @@ void add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(
             PassThrough,
             Activation_Mul_Clamp<PassThrough>>>>& instances);
 #endif
 // Layout(A, B, C) = [Col, Row, Row]
 void add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(
     std::vector<std::unique_ptr<DeviceGemmMultipleD<Col,
...
@@ -181,7 +181,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
         {
             if constexpr(is_same_v<Activation, PassThrough>)
             {
+#ifdef DL_KERNELS
                 add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(op_ptrs);
+#endif
                 add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(op_ptrs);
             }
         }
...
@@ -190,7 +192,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
         {
             if constexpr(is_same_v<Activation, PassThrough>)
             {
+#ifdef DL_KERNELS
                 add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(op_ptrs);
+#endif
                 add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(op_ptrs);
             }
         }
...
@@ -199,7 +203,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
         {
             if constexpr(is_same_v<Activation, PassThrough>)
             {
+#ifdef DL_KERNELS
                 add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_kn_mn_instances(op_ptrs);
+#endif
                 add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(op_ptrs);
             }
         }
...
@@ -208,7 +214,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
         {
             if constexpr(is_same_v<Activation, PassThrough>)
             {
+#ifdef DL_KERNELS
                 add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_nk_mn_instances(op_ptrs);
+#endif
                 add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(op_ptrs);
             }
         }
...
@@ -222,3 +230,4 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
+#endif
\ No newline at end of file
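The recurring change in this header, and in the four grouped-convolution quantization headers below, is the same: each DL (dot-product) instance adder is registered only when the library is built with DL kernels, while the XDL adder stays unconditional, so disabling DL kernels shrinks the factory's instance list instead of breaking the build. A self-contained sketch of that registration pattern (all names here are illustrative, not CK API; only the guard shape is taken from the hunks above):

// Optional kernel families behind a build flag; the factory degrades gracefully.
#include <cstdio>
#include <functional>
#include <vector>

using Registry = std::vector<std::function<void()>>;

void add_dl_instances(Registry& r)  { r.push_back([] { std::puts("dl kernel"); }); }
void add_xdl_instances(Registry& r) { r.push_back([] { std::puts("xdl kernel"); }); }

Registry get_instances()
{
    Registry op_ptrs;
#ifdef DL_KERNELS // optional family, mirrors the guards added in this commit
    add_dl_instances(op_ptrs);
#endif
    add_xdl_instances(op_ptrs); // always-available family
    return op_ptrs;
}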
library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_bias_forward_perchannel_quantization.hpp

@@ -11,12 +11,12 @@
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"

 #ifdef CK_ENABLE_INT8
 namespace ck {
 namespace tensor_operation {
 namespace device {
 namespace instance {
 #ifdef DL_KERNELS
 // grouped conv2d forward, NHWGC/GKYXC/NHWGK
 void add_device_conv2d_dl_bias_perchannel_quantization_int8_instances(
     std::vector<
...
@@ -64,7 +64,7 @@ void add_device_conv2d_dl_bias_tanh_perchannel_quantization_int8_instances(
             PassThrough,
             Add_Mul2_Activation_Mul_Clamp<TanH>>>>& instances);
 #endif
 void add_device_conv2d_xdl_bias_perchannel_quantization_int8_instances(
     std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
...
@@ -163,12 +163,16 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
         {
             if constexpr(is_same_v<Activation, PassThrough>)
             {
+#ifdef DL_KERNELS
                 add_device_conv2d_dl_bias_perchannel_quantization_int8_instances(op_ptrs);
+#endif
                 add_device_conv2d_xdl_bias_perchannel_quantization_int8_instances(op_ptrs);
             }
             else if constexpr(is_same_v<Activation, Relu>)
             {
+#ifdef DL_KERNELS
                 add_device_conv2d_dl_bias_relu_perchannel_quantization_int8_instances(op_ptrs);
+#endif
                 add_device_conv2d_xdl_bias_relu_perchannel_quantization_int8_instances(op_ptrs);
             }
         }
...
@@ -229,7 +233,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
         {
             if constexpr(is_same_v<Activation, TanH>)
             {
+#ifdef DL_KERNELS
                 add_device_conv2d_dl_bias_tanh_perchannel_quantization_int8_instances(op_ptrs);
+#endif
                 add_device_conv2d_xdl_bias_tanh_perchannel_quantization_int8_instances(op_ptrs);
             }
         }
...
@@ -243,3 +249,4 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
+#endif
library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_bias_forward_perlayer_quantization.hpp

@@ -11,12 +11,12 @@
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"

 #ifdef CK_ENABLE_INT8
 namespace ck {
 namespace tensor_operation {
 namespace device {
 namespace instance {
 #ifdef DL_KERNELS
 // grouped conv2d forward, NHWGC/GKYXC/NHWGK
 void add_device_conv2d_dl_bias_perlayer_quantization_int8_instances(
     std::vector<
...
@@ -63,7 +63,7 @@ void add_device_conv2d_dl_bias_tanh_perlayer_quantization_int8_instances(
             PassThrough,
             Add_Mul_Activation_Mul_Clamp<TanH>>>>& instances);
 #endif
 void add_device_conv2d_xdl_bias_perlayer_quantization_int8_instances(
     std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
...
@@ -161,12 +161,16 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
         {
             if constexpr(is_same_v<Activation, PassThrough>)
             {
+#ifdef DL_KERNELS
                 add_device_conv2d_dl_bias_perlayer_quantization_int8_instances(op_ptrs);
+#endif
                 add_device_conv2d_xdl_bias_perlayer_quantization_int8_instances(op_ptrs);
             }
             else if constexpr(is_same_v<Activation, Relu>)
             {
+#ifdef DL_KERNELS
                 add_device_conv2d_dl_bias_relu_perlayer_quantization_int8_instances(op_ptrs);
+#endif
                 add_device_conv2d_xdl_bias_relu_perlayer_quantization_int8_instances(op_ptrs);
             }
         }
...
@@ -227,7 +231,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
         {
             if constexpr(is_same_v<Activation, TanH>)
             {
+#ifdef DL_KERNELS
                 add_device_conv2d_dl_bias_tanh_perlayer_quantization_int8_instances(op_ptrs);
+#endif
                 add_device_conv2d_xdl_bias_tanh_perlayer_quantization_int8_instances(op_ptrs);
             }
         }
...
@@ -241,3 +247,4 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
+#endif
library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_forward_perchannel_quantization.hpp

@@ -11,12 +11,12 @@
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"

 #ifdef CK_ENABLE_INT8
 namespace ck {
 namespace tensor_operation {
 namespace device {
 namespace instance {
 #ifdef DL_KERNELS
 // grouped conv2d forward, NHWGC/GKYXC/NHWGK
 void add_device_conv2d_dl_perchannel_quantization_int8_instances(
     std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
...
@@ -47,7 +47,7 @@ void add_device_conv2d_dl_relu_perchannel_quantization_int8_instances(
             PassThrough,
             Activation_Mul2_Clamp<Relu>>>>& instances);
 #endif
 void add_device_conv2d_xdl_perchannel_quantization_int8_instances(
     std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
                                                               NHWGC,
...
@@ -128,12 +128,16 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
         {
             if constexpr(is_same_v<Activation, PassThrough>)
             {
+#ifdef DL_KERNELS
                 add_device_conv2d_dl_perchannel_quantization_int8_instances(op_ptrs);
+#endif
                 add_device_conv2d_xdl_perchannel_quantization_int8_instances(op_ptrs);
             }
             else if constexpr(is_same_v<Activation, Relu>)
             {
+#ifdef DL_KERNELS
                 add_device_conv2d_dl_relu_perchannel_quantization_int8_instances(op_ptrs);
+#endif
                 add_device_conv2d_xdl_relu_perchannel_quantization_int8_instances(op_ptrs);
             }
         }
...
@@ -147,3 +151,4 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
+#endif
library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_forward_perlayer_quantization.hpp

@@ -11,12 +11,12 @@
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"

 #ifdef CK_ENABLE_INT8
 namespace ck {
 namespace tensor_operation {
 namespace device {
 namespace instance {
 #ifdef DL_KERNELS
 // grouped conv2d forward, NHWGC/GKYXC/NHWGK
 void add_device_conv2d_dl_perlayer_quantization_int8_instances(
     std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
...
@@ -47,7 +47,7 @@ void add_device_conv2d_dl_relu_perlayer_quantization_int8_instances(
             PassThrough,
             Activation_Mul_Clamp<Relu>>>>& instances);
 #endif
 void add_device_conv2d_xdl_perlayer_quantization_int8_instances(
     std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
                                                               NHWGC,
...
@@ -125,12 +125,16 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
         {
             if constexpr(is_same_v<Activation, PassThrough>)
             {
+#ifdef DL_KERNELS
                 add_device_conv2d_dl_perlayer_quantization_int8_instances(op_ptrs);
+#endif
                 add_device_conv2d_xdl_perlayer_quantization_int8_instances(op_ptrs);
             }
             else if constexpr(is_same_v<Activation, Relu>)
             {
+#ifdef DL_KERNELS
                 add_device_conv2d_dl_relu_perlayer_quantization_int8_instances(op_ptrs);
+#endif
                 add_device_conv2d_xdl_relu_perlayer_quantization_int8_instances(op_ptrs);
             }
         }
...
@@ -144,3 +148,4 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
+#endif
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp

@@ -89,13 +89,13 @@ void add_device_reduce_instance_blockwise(
 {
     static_for<0, std::tuple_size<reduce_configuration_1_instances_blockwise>::value, 1>{}(
         [&](auto i) {
             using cfg1 = remove_cvref_t<decltype(
                 std::get<i.value>(reduce_configuration_1_instances_blockwise{}))>;

             static_for<0, std::tuple_size<reduce_configuration_2_instances_blockwise>::value, 1>{}(
                 [&](auto j) {
                     using cfg2 = remove_cvref_t<decltype(
                         std::get<j.value>(reduce_configuration_2_instances_blockwise{}))>;

                     using ReduceOpInstance = DeviceReduceMultiBlock<InDataType,
...
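The nesting in this hunk is CK's compile-time enumeration idiom: static_for walks a std::tuple of configuration types by index, and each (cfg1, cfg2) pair stamps out one DeviceReduceMultiBlock instantiation. A standalone sketch of the same idiom, with an illustrative config list in place of CK's (only the loop shape is taken from the code above; Cfg, configs, and the printf are hypothetical):

#include <cstdio>
#include <tuple>
#include <type_traits>

template <int N> struct Cfg { static constexpr int block_size = N; };
using configs = std::tuple<Cfg<64>, Cfg<128>, Cfg<256>>;

// Minimal stand-in for ck::static_for: invokes f(integral_constant<i>) for i in [B, E).
template <int B, int E, int Inc> struct static_for {
    template <typename F> void operator()(F f) const {
        if constexpr(B < E) {
            f(std::integral_constant<int, B>{});
            static_for<B + Inc, E, Inc>{}(f);
        }
    }
};

int main() {
    static_for<0, std::tuple_size<configs>::value, 1>{}([&](auto i) {
        // Same move as `using cfg1 = remove_cvref_t<decltype(std::get<i.value>(...{}))>;`:
        // i.value is a compile-time constant, so each iteration sees a distinct type.
        using cfg = std::remove_cv_t<
            std::remove_reference_t<decltype(std::get<i.value>(configs{}))>>;
        std::printf("instantiating config with block_size = %d\n", cfg::block_size);
    });
}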
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp

@@ -90,14 +90,14 @@ void add_device_reduce_instance_multiblock_atomic_add(
     static_for<0, std::tuple_size<reduce_configuration_1_instances_multiblock_atomic_add>::value, 1>{}(
         [&](auto i) {
             using cfg1 = remove_cvref_t<decltype(
                 std::get<i.value>(reduce_configuration_1_instances_multiblock_atomic_add{}))>;

             static_for<0, std::tuple_size<reduce_configuration_2_instances_multiblock_atomic_add>::value, 1>{}(
                 [&](auto j) {
                     using cfg2 = remove_cvref_t<decltype(
                         std::get<j.value>(reduce_configuration_2_instances_multiblock_atomic_add{}))>;

                     using ReduceOpInstance = DeviceReduceMultiBlock<InDataType,
                                                                     AccDataType,
...
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp

@@ -77,8 +77,8 @@ void add_device_reduce_instance_threadwise(
     static_for<0, std::tuple_size<reduce_configuration_2_instances_threadwise>::value, 1>{}(
         [&](auto j) {
             using cfg2 = remove_cvref_t<decltype(
                 std::get<j.value>(reduce_configuration_2_instances_threadwise{}))>;

             using ReduceOpInstance = DeviceReduceThreadWise<InDataType,
                                                             AccDataType,
...
library/include/ck/library/tensor_operation_instance/gpu/softmax.hpp

@@ -9,64 +9,89 @@
 #include "ck/ck.hpp"
 #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
 #include "ck/tensor_operation/gpu/device/device_softmax.hpp"
+#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_instance.hpp"

 namespace ck {
 namespace tensor_operation {
 namespace device {
 namespace instance {

-void add_device_softmax_f16_f16_rank3_instances(
-    std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 3>>&);
-void add_device_softmax_f16_f16_rank4_instances(
-    std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 4>>&);
-void add_device_softmax_f32_f32_rank3_instances(
-    std::vector<DeviceSoftmaxPtr<F32, F32, F32, PassThrough, PassThrough, 3>>&);
-void add_device_softmax_f32_f32_rank4_instances(
-    std::vector<DeviceSoftmaxPtr<F32, F32, F32, PassThrough, PassThrough, 4>>&);
 void add_device_softmax_i8_i8_rank3_instances(
     std::vector<DeviceSoftmaxPtr<I8, F32, I8, PassThrough, PassThrough, 3>>&);
 void add_device_softmax_i8_i8_rank4_instances(
     std::vector<DeviceSoftmaxPtr<I8, F32, I8, PassThrough, PassThrough, 4>>&);

-template <typename InDataType, typename AccDataType, typename OutDataType, index_t Rank>
-struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceSoftmax<InDataType,
-                                                                                  AccDataType,
-                                                                                  OutDataType,
-                                                                                  PassThrough,
-                                                                                  PassThrough,
-                                                                                  Rank>>
+template <typename InDataType,
+          typename AccDataType,
+          typename OutDataType,
+          index_t Rank,
+          index_t NumReduceDim>
+struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceSoftmax<InDataType,
+                                                                                  AccDataType,
+                                                                                  OutDataType,
+                                                                                  PassThrough,
+                                                                                  PassThrough,
+                                                                                  Rank,
+                                                                                  NumReduceDim>>
 {
-    using DeviceOp =
-        DeviceSoftmax<InDataType, AccDataType, OutDataType, PassThrough, PassThrough, Rank>;
+    using DeviceOp = DeviceSoftmax<InDataType,
+                                   AccDataType,
+                                   OutDataType,
+                                   PassThrough,
+                                   PassThrough,
+                                   Rank,
+                                   NumReduceDim>;

     static auto GetInstances()
     {
         std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

+#ifdef CK_ENABLE_FP16
         if constexpr(std::is_same_v<InDataType, F16> && std::is_same_v<AccDataType, F32> &&
                      std::is_same_v<OutDataType, F16>)
         {
             if constexpr(Rank == 3)
-                add_device_softmax_f16_f16_rank3_instances(op_ptrs);
+            {
+                if constexpr(NumReduceDim == 1)
+                    add_device_softmax_f16_f16_rank3_reduce1_instances(op_ptrs);
+                else if constexpr(NumReduceDim == 2)
+                    add_device_softmax_f16_f16_rank3_reduce2_instances(op_ptrs);
+                else if constexpr(NumReduceDim == 3)
+                    add_device_softmax_f16_f16_rank3_reduce3_instances(op_ptrs);
+            }
             else if constexpr(Rank == 4)
-                add_device_softmax_f16_f16_rank4_instances(op_ptrs);
+            {
+                if constexpr(NumReduceDim == 1)
+                    add_device_softmax_f16_f16_rank4_reduce1_instances(op_ptrs);
+                else if constexpr(NumReduceDim == 2)
+                    add_device_softmax_f16_f16_rank4_reduce2_instances(op_ptrs);
+                else if constexpr(NumReduceDim == 3)
+                    add_device_softmax_f16_f16_rank4_reduce3_instances(op_ptrs);
+                else if constexpr(NumReduceDim == 4)
+                    add_device_softmax_f16_f16_rank4_reduce4_instances(op_ptrs);
+            }
         }
-        else if constexpr(std::is_same_v<InDataType, F32> && std::is_same_v<AccDataType, F32> &&
-                          std::is_same_v<OutDataType, F32>)
+#endif
+#ifdef CK_ENABLE_FP32
+        if constexpr(std::is_same_v<InDataType, F32> && std::is_same_v<AccDataType, F32> &&
+                     std::is_same_v<OutDataType, F32>)
         {
             if constexpr(Rank == 3)
-                add_device_softmax_f32_f32_rank3_instances(op_ptrs);
+            {
+                if constexpr(NumReduceDim == 1)
+                    add_device_softmax_f32_f32_rank3_reduce1_instances(op_ptrs);
+                else if constexpr(NumReduceDim == 2)
+                    add_device_softmax_f32_f32_rank3_reduce2_instances(op_ptrs);
+                else if constexpr(NumReduceDim == 3)
+                    add_device_softmax_f32_f32_rank3_reduce3_instances(op_ptrs);
+            }
             else if constexpr(Rank == 4)
-                add_device_softmax_f32_f32_rank4_instances(op_ptrs);
+            {
+                if constexpr(NumReduceDim == 1)
+                    add_device_softmax_f32_f32_rank4_reduce1_instances(op_ptrs);
+                else if constexpr(NumReduceDim == 2)
+                    add_device_softmax_f32_f32_rank4_reduce2_instances(op_ptrs);
+                else if constexpr(NumReduceDim == 3)
+                    add_device_softmax_f32_f32_rank4_reduce3_instances(op_ptrs);
+                else if constexpr(NumReduceDim == 4)
+                    add_device_softmax_f32_f32_rank4_reduce4_instances(op_ptrs);
+            }
         }
         else if constexpr(std::is_same_v<InDataType, I8> && std::is_same_v<AccDataType, F32> &&
                           std::is_same_v<OutDataType, I8>)
         {
             if constexpr(Rank == 3)
                 add_device_softmax_i8_i8_rank3_instances(op_ptrs);
             else if constexpr(Rank == 4)
                 add_device_softmax_i8_i8_rank4_instances(op_ptrs);
         }
+#endif
         return op_ptrs;
     }
 };
...
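With NumReduceDim now part of the factory key, callers select an instance set by rank and by how many trailing dimensions are reduced, instead of by rank alone. A usage sketch (not part of the commit), assuming an FP16 build and the usual F16 = ck::half_t, F32 = float aliases:

// Sketch: rank-3 softmax reducing the single innermost dimension; per the
// dispatch above this resolves to add_device_softmax_f16_f16_rank3_reduce1_instances.
#include <cstddef>
#include "ck/library/tensor_operation_instance/gpu/softmax.hpp"

using PassThroughOp = ck::tensor_operation::element_wise::PassThrough;

using SoftmaxOp = ck::tensor_operation::device::DeviceSoftmax<
    ck::half_t,    // InDataType
    float,         // AccDataType
    ck::half_t,    // OutDataType
    PassThroughOp, // InElementwiseOp
    PassThroughOp, // AccElementwiseOp
    3,             // Rank
    1>;            // NumReduceDim (new template parameter)

std::size_t count_softmax_instances()
{
    auto op_ptrs = ck::tensor_operation::device::instance::
        DeviceOperationInstanceFactory<SoftmaxOp>::GetInstances();
    return op_ptrs.size();
}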
library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance.hpp
deleted 100644 → 0

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/device_softmax.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

void add_device_softmax_f16_f16_rank3_instances(
    std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 3>>& instances);
void add_device_softmax_f16_f16_rank4_instances(
    std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 4>>& instances);

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce1.hpp

@@ -14,7 +14,7 @@ namespace device {
 namespace instance {

 void add_device_softmax_f16_f16_rank3_reduce1_instances(
-    std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 3>>& instances);
+    std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 3, 1>>& instances);

 } // namespace instance
 } // namespace device
...

library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce2.hpp

@@ -14,7 +14,7 @@ namespace device {
 namespace instance {

 void add_device_softmax_f16_f16_rank3_reduce2_instances(
-    std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 3>>& instances);
+    std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 3, 2>>& instances);

 } // namespace instance
 } // namespace device
...

library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce3.hpp

@@ -14,7 +14,7 @@ namespace device {
 namespace instance {

 void add_device_softmax_f16_f16_rank3_reduce3_instances(
-    std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 3>>& instances);
+    std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 3, 3>>& instances);

 } // namespace instance
 } // namespace device
...

library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce1.hpp

@@ -14,7 +14,7 @@ namespace device {
 namespace instance {

 void add_device_softmax_f16_f16_rank4_reduce1_instances(
-    std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 4>>& instances);
+    std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 4, 1>>& instances);

 } // namespace instance
 } // namespace device
...

library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce2.hpp

@@ -14,7 +14,7 @@ namespace device {
 namespace instance {

 void add_device_softmax_f16_f16_rank4_reduce2_instances(
-    std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 4>>& instances);
+    std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 4, 2>>& instances);

 } // namespace instance
 } // namespace device
...

library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce3.hpp

@@ -14,7 +14,7 @@ namespace device {
 namespace instance {

 void add_device_softmax_f16_f16_rank4_reduce3_instances(
-    std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 4>>& instances);
+    std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 4, 3>>& instances);

 } // namespace instance
 } // namespace device
...

library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce4.hpp

@@ -14,7 +14,7 @@ namespace device {
 namespace instance {

 void add_device_softmax_f16_f16_rank4_reduce4_instances(
-    std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 4>>& instances);
+    std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 4, 4>>& instances);

 } // namespace instance
 } // namespace device
...