Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
yangql
composable_kernel-1
Commits
fd11a4a1
Unverified
Commit
fd11a4a1
authored
Apr 17, 2023
by
rocking5566
Committed by
GitHub
Apr 17, 2023
Browse files
Add (#677)
parent
fc26d42a
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
65 additions
and
2 deletions
+65
-2
client_example/18_groupnorm/groupnorm_swish.cpp
client_example/18_groupnorm/groupnorm_swish.cpp
+2
-2
library/include/ck/library/tensor_operation_instance/gpu/normalization_swish.hpp
...ary/tensor_operation_instance/gpu/normalization_swish.hpp
+12
-0
library/src/tensor_operation_instance/gpu/normalization/CMakeLists.txt
...ensor_operation_instance/gpu/normalization/CMakeLists.txt
+1
-0
library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_swish_f16_f32_f32_f16_instance.cpp
...ation/device_groupnorm_swish_f16_f32_f32_f16_instance.cpp
+24
-0
library/src/tensor_operation_instance/gpu/normalization/normalization_instance_common.hpp
...tance/gpu/normalization/normalization_instance_common.hpp
+26
-0
No files found.
client_example/18_groupnorm/groupnorm_swish.cpp
View file @
fd11a4a1
...
...
@@ -13,8 +13,8 @@
#include "ck/library/tensor_operation_instance/gpu/normalization_swish.hpp"
using
XDataType
=
ck
::
half_t
;
using
GammaDataType
=
ck
::
half_
t
;
using
BetaDataType
=
ck
::
half_
t
;
using
GammaDataType
=
floa
t
;
using
BetaDataType
=
floa
t
;
using
YDataType
=
ck
::
half_t
;
using
ComputeDataType
=
float
;
using
Swish
=
ck
::
tensor_operation
::
element_wise
::
Swish
;
...
...
library/include/ck/library/tensor_operation_instance/gpu/normalization_swish.hpp
View file @
fd11a4a1
...
...
@@ -25,6 +25,10 @@ void add_device_normalization_rank_5_3_swish_f16_instances(
void
add_device_normalization_rank_5_3_swish_f32_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceNormalization
<
F32
,
F32
,
F32
,
F32
,
F32
,
Swish
,
5
,
3
>>>&
);
// [x, gamma, beta, y] = [f16, f32, f32, f16]
void
add_device_normalization_rank_5_3_swish_f16_f32_f32_f16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceNormalization
<
F16
,
F32
,
F32
,
F32
,
F16
,
Swish
,
5
,
3
>>>&
);
template
<
typename
XDataType
,
typename
GammaDataType
,
typename
BetaDataType
,
...
...
@@ -70,6 +74,14 @@ struct DeviceOperationInstanceFactory<
add_device_normalization_rank_5_3_swish_f32_instances
(
op_ptrs
);
}
}
else
if
constexpr
(
is_same_v
<
XDataType
,
F16
>
&&
is_same_v
<
GammaDataType
,
F32
>
&&
is_same_v
<
BetaDataType
,
F32
>
&&
is_same_v
<
YDataType
,
F16
>
)
{
if
constexpr
(
Rank
==
5
&&
NumReduceDim
==
3
)
{
add_device_normalization_rank_5_3_swish_f16_f32_f32_f16_instances
(
op_ptrs
);
}
}
return
op_ptrs
;
}
...
...
library/src/tensor_operation_instance/gpu/normalization/CMakeLists.txt
View file @
fd11a4a1
...
...
@@ -7,4 +7,5 @@ add_instance_library(device_normalization_instance
device_groupnorm_f32_instance.cpp
device_groupnorm_swish_f16_instance.cpp
device_groupnorm_swish_f32_instance.cpp
device_groupnorm_swish_f16_f32_f32_f16_instance.cpp
)
library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_swish_f16_f32_f32_f16_instance.cpp
0 → 100644
View file @
fd11a4a1
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "normalization_instance_common.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
using
Swish
=
ck
::
tensor_operation
::
element_wise
::
Swish
;
void
add_device_normalization_rank_5_3_swish_f16_f32_f32_f16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceNormalization
<
F16
,
F32
,
F32
,
F32
,
F16
,
Swish
,
5
,
3
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_normalization_f16_f32_f32_f16_instances
<
Swish
,
5
,
3
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/normalization/normalization_instance_common.hpp
View file @
fd11a4a1
...
...
@@ -69,6 +69,32 @@ using device_normalization_f32_instances = std::tuple<
// clang-format on
>
;
template
<
typename
OutElementwise
,
index_t
Rank
,
index_t
Reduce
>
using
device_normalization_f16_f32_f32_f16_instances
=
std
::
tuple
<
// clang-format off
// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize>
DeviceNormalizationImpl
<
F16
,
F32
,
F32
,
F32
,
F16
,
OutElementwise
,
Rank
,
Reduce
,
128
,
1
,
128
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
>
,
// irregular size
DeviceNormalizationImpl
<
F16
,
F32
,
F32
,
F32
,
F16
,
OutElementwise
,
Rank
,
Reduce
,
256
,
1
,
256
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
>
,
// irregular size
DeviceNormalizationImpl
<
F16
,
F32
,
F32
,
F32
,
F16
,
OutElementwise
,
Rank
,
Reduce
,
512
,
1
,
512
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
>
,
// irregular size
DeviceNormalizationImpl
<
F16
,
F32
,
F32
,
F32
,
F16
,
OutElementwise
,
Rank
,
Reduce
,
1024
,
1
,
1024
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
>
,
// irregular size
DeviceNormalizationImpl
<
F16
,
F32
,
F32
,
F32
,
F16
,
OutElementwise
,
Rank
,
Reduce
,
256
,
1
,
256
,
1
,
2
,
1
,
2
,
1
,
2
,
1
,
2
,
2
>
,
// irregular size
DeviceNormalizationImpl
<
F16
,
F32
,
F32
,
F32
,
F16
,
OutElementwise
,
Rank
,
Reduce
,
128
,
1
,
128
,
1
,
4
,
1
,
4
,
1
,
4
,
1
,
4
,
4
>
,
DeviceNormalizationImpl
<
F16
,
F32
,
F32
,
F32
,
F16
,
OutElementwise
,
Rank
,
Reduce
,
128
,
1
,
128
,
1
,
8
,
1
,
4
,
1
,
4
,
1
,
4
,
4
>
,
DeviceNormalizationImpl
<
F16
,
F32
,
F32
,
F32
,
F16
,
OutElementwise
,
Rank
,
Reduce
,
128
,
1
,
128
,
1
,
16
,
1
,
4
,
1
,
4
,
1
,
4
,
4
>
,
DeviceNormalizationImpl
<
F16
,
F32
,
F32
,
F32
,
F16
,
OutElementwise
,
Rank
,
Reduce
,
128
,
1
,
128
,
1
,
32
,
1
,
4
,
1
,
4
,
1
,
4
,
4
>
,
DeviceNormalizationImpl
<
F16
,
F32
,
F32
,
F32
,
F16
,
OutElementwise
,
Rank
,
Reduce
,
256
,
1
,
256
,
1
,
4
,
1
,
4
,
1
,
4
,
1
,
4
,
4
>
,
DeviceNormalizationImpl
<
F16
,
F32
,
F32
,
F32
,
F16
,
OutElementwise
,
Rank
,
Reduce
,
256
,
1
,
256
,
1
,
8
,
1
,
4
,
1
,
4
,
1
,
4
,
4
>
,
DeviceNormalizationImpl
<
F16
,
F32
,
F32
,
F32
,
F16
,
OutElementwise
,
Rank
,
Reduce
,
256
,
1
,
256
,
1
,
16
,
1
,
4
,
1
,
4
,
1
,
4
,
4
>
,
DeviceNormalizationImpl
<
F16
,
F32
,
F32
,
F32
,
F16
,
OutElementwise
,
Rank
,
Reduce
,
256
,
1
,
256
,
2
,
16
,
1
,
4
,
1
,
4
,
1
,
4
,
4
>
,
DeviceNormalizationImpl
<
F16
,
F32
,
F32
,
F32
,
F16
,
OutElementwise
,
Rank
,
Reduce
,
256
,
1
,
256
,
1
,
32
,
1
,
4
,
1
,
4
,
1
,
4
,
4
>
,
DeviceNormalizationImpl
<
F16
,
F32
,
F32
,
F32
,
F16
,
OutElementwise
,
Rank
,
Reduce
,
512
,
1
,
512
,
1
,
4
,
1
,
4
,
1
,
4
,
1
,
4
,
4
>
,
DeviceNormalizationImpl
<
F16
,
F32
,
F32
,
F32
,
F16
,
OutElementwise
,
Rank
,
Reduce
,
512
,
1
,
512
,
1
,
8
,
1
,
4
,
1
,
4
,
1
,
4
,
4
>
,
DeviceNormalizationImpl
<
F16
,
F32
,
F32
,
F32
,
F16
,
OutElementwise
,
Rank
,
Reduce
,
512
,
1
,
512
,
2
,
8
,
1
,
4
,
1
,
4
,
1
,
4
,
4
>
,
DeviceNormalizationImpl
<
F16
,
F32
,
F32
,
F32
,
F16
,
OutElementwise
,
Rank
,
Reduce
,
1024
,
1
,
1024
,
1
,
4
,
1
,
4
,
1
,
4
,
1
,
4
,
4
>
,
DeviceNormalizationImpl
<
F16
,
F32
,
F32
,
F32
,
F16
,
OutElementwise
,
Rank
,
Reduce
,
1024
,
1
,
1024
,
1
,
8
,
1
,
4
,
1
,
4
,
1
,
4
,
4
>
// clang-format on
>
;
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment