Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
c6b98c98
Commit
c6b98c98
authored
Oct 24, 2023
by
Astha Rai
Browse files
added fp 16 type check in unary square
parent
aa61ccf0
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
38 additions
and
28 deletions
+38
-28
example/65_hip_tensor_permute/elementwise_permute_4D_fp16_ht.cpp
.../65_hip_tensor_permute/elementwise_permute_4D_fp16_ht.cpp
+24
-14
example/65_hip_tensor_permute/elementwise_permute_4D_fp32_ht.cpp
.../65_hip_tensor_permute/elementwise_permute_4D_fp32_ht.cpp
+12
-12
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
...or_operation/gpu/element/unary_element_wise_operation.hpp
+2
-2
No files found.
example/65_hip_tensor_permute/elementwise_permute_4D_fp16_ht.cpp
View file @
c6b98c98
...
@@ -3,7 +3,7 @@
...
@@ -3,7 +3,7 @@
#include "ck/ck.hpp"
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_impl.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_impl
_ht
.hpp"
#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/check_err.hpp"
...
@@ -18,25 +18,35 @@ using ADataType = F16;
...
@@ -18,25 +18,35 @@ using ADataType = F16;
using
BDataType
=
F16
;
using
BDataType
=
F16
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
UnaryOp
=
ck
::
tensor_operation
::
element_wise
::
UnarySquare
;
// ck::index_t scalar_mult = 2;
using
DeviceElementwisePermuteInstance
=
using
DeviceElementwisePermuteInstance
=
ck
::
tensor_operation
::
device
::
DeviceElementwiseImpl
<
ck
::
Tuple
<
ADataType
>
,
ck
::
tensor_operation
::
device
::
DeviceElementwiseImpl
<
ck
::
Tuple
<
ADataType
>
,
// InDataTypeTuple
ck
::
Tuple
<
BDataType
>
,
ck
::
Tuple
<
BDataType
>
,
// OutDataTypeTuple
PassThrough
,
PassThrough
,
// ElementwiseOp
4
,
UnaryOp
,
// UnaryOp
8
,
4
,
// NumDim
ck
::
Sequence
<
8
>
,
8
,
// MPerThread
ck
::
Sequence
<
1
>>
;
2
,
// ScalarMult (alpha)
ck
::
Sequence
<
8
>
,
// InScalarPerVectorSeq
template
<
typename
HostTensorA
,
typename
HostTensorB
,
typename
Functor
>
ck
::
Sequence
<
1
>>
;
// OutScalarPerVectorSeq
void
host_elementwise4D
(
HostTensorB
&
B_nhwc
,
const
HostTensorA
&
A_nchw
,
Functor
functor
)
template
<
typename
HostTensorA
,
typename
HostTensorB
,
typename
FunctorA
,
typename
FunctorB
>
void
host_elementwise4D
(
HostTensorB
&
B_nhwc
,
const
HostTensorA
&
A_nchw
,
FunctorA
functor_a
,
FunctorB
functor_b
)
{
{
for
(
std
::
size_t
n
=
0
;
n
<
A_nchw
.
mDesc
.
GetLengths
()[
0
];
++
n
)
for
(
std
::
size_t
n
=
0
;
n
<
A_nchw
.
mDesc
.
GetLengths
()[
0
];
++
n
)
for
(
std
::
size_t
c
=
0
;
c
<
A_nchw
.
mDesc
.
GetLengths
()[
1
];
++
c
)
for
(
std
::
size_t
c
=
0
;
c
<
A_nchw
.
mDesc
.
GetLengths
()[
1
];
++
c
)
for
(
std
::
size_t
h
=
0
;
h
<
A_nchw
.
mDesc
.
GetLengths
()[
2
];
++
h
)
for
(
std
::
size_t
h
=
0
;
h
<
A_nchw
.
mDesc
.
GetLengths
()[
2
];
++
h
)
for
(
std
::
size_t
w
=
0
;
w
<
A_nchw
.
mDesc
.
GetLengths
()[
3
];
++
w
)
for
(
std
::
size_t
w
=
0
;
w
<
A_nchw
.
mDesc
.
GetLengths
()[
3
];
++
w
)
{
{
ADataType
tmp_val
;
auto
a_val
=
A_nchw
(
n
,
c
,
h
,
w
);
auto
a_val
=
A_nchw
(
n
,
c
,
h
,
w
);
functor
(
B_nhwc
(
n
,
h
,
w
,
c
),
a_val
);
functor_b
(
tmp_val
,
a_val
);
functor_a
(
B_nhwc
(
n
,
h
,
w
,
c
),
2
*
tmp_val
);
}
}
}
}
...
@@ -74,7 +84,7 @@ int main()
...
@@ -74,7 +84,7 @@ int main()
auto
broadcastPermute
=
DeviceElementwisePermuteInstance
{};
auto
broadcastPermute
=
DeviceElementwisePermuteInstance
{};
auto
argument
=
broadcastPermute
.
MakeArgumentPointer
(
auto
argument
=
broadcastPermute
.
MakeArgumentPointer
(
ab_lengths
,
{
a_strides
},
{
b_strides
},
input
,
output
,
PassThrough
{});
ab_lengths
,
{
a_strides
},
{
b_strides
},
input
,
output
,
PassThrough
{}
,
UnaryOp
{}
);
if
(
!
broadcastPermute
.
IsSupportedArgument
(
argument
.
get
()))
if
(
!
broadcastPermute
.
IsSupportedArgument
(
argument
.
get
()))
{
{
...
@@ -106,7 +116,7 @@ int main()
...
@@ -106,7 +116,7 @@ int main()
{
{
b_device_buf
.
FromDevice
(
b
.
mData
.
data
());
b_device_buf
.
FromDevice
(
b
.
mData
.
data
());
Tensor
<
BDataType
>
host_b
(
nhwc
);
Tensor
<
BDataType
>
host_b
(
nhwc
);
host_elementwise4D
(
host_b
,
a
,
PassThrough
{});
host_elementwise4D
(
host_b
,
a
,
PassThrough
{}
,
UnaryOp
{}
);
pass
&=
pass
&=
ck
::
utils
::
check_err
(
b
.
mData
,
host_b
.
mData
,
"Error: Incorrect results b"
,
1e-3
,
1e-3
);
ck
::
utils
::
check_err
(
b
.
mData
,
host_b
.
mData
,
"Error: Incorrect results b"
,
1e-3
,
1e-3
);
...
...
example/65_hip_tensor_permute/elementwise_permute_4D_fp32_ht.cpp
View file @
c6b98c98
...
@@ -18,19 +18,19 @@ using ADataType = F32;
...
@@ -18,19 +18,19 @@ using ADataType = F32;
using
BDataType
=
F32
;
using
BDataType
=
F32
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
Square
=
ck
::
tensor_operation
::
element_wise
::
UnarySquare
;
using
UnaryOp
=
ck
::
tensor_operation
::
element_wise
::
UnarySquare
;
// ck::index_t scalar_mult = 2;
// ck::index_t scalar_mult = 2;
using
DeviceElementwisePermuteInstance
=
using
DeviceElementwisePermuteInstance
=
ck
::
tensor_operation
::
device
::
DeviceElementwiseImpl
<
ck
::
Tuple
<
ADataType
>
,
ck
::
tensor_operation
::
device
::
DeviceElementwiseImpl
<
ck
::
Tuple
<
ADataType
>
,
// InDataTypeTuple
ck
::
Tuple
<
BDataType
>
,
ck
::
Tuple
<
BDataType
>
,
// OutDataTypeTuple
PassThrough
,
PassThrough
,
// ElementwiseOp
Square
,
UnaryOp
,
// UnaryOp
4
,
4
,
// NumDim
8
,
8
,
// MPerThread
2
,
2
,
// ScalarMult (alpha)
ck
::
Sequence
<
8
>
,
ck
::
Sequence
<
8
>
,
// InScalarPerVectorSeq
ck
::
Sequence
<
1
>>
;
ck
::
Sequence
<
1
>>
;
// OutScalarPerVectorSeq
template
<
typename
HostTensorA
,
typename
HostTensorB
,
typename
FunctorA
,
typename
FunctorB
>
template
<
typename
HostTensorA
,
typename
HostTensorB
,
typename
FunctorA
,
typename
FunctorB
>
void
host_elementwise4D
(
HostTensorB
&
B_nhwc
,
void
host_elementwise4D
(
HostTensorB
&
B_nhwc
,
...
@@ -84,7 +84,7 @@ int main()
...
@@ -84,7 +84,7 @@ int main()
auto
broadcastPermute
=
DeviceElementwisePermuteInstance
{};
auto
broadcastPermute
=
DeviceElementwisePermuteInstance
{};
auto
argument
=
broadcastPermute
.
MakeArgumentPointer
(
auto
argument
=
broadcastPermute
.
MakeArgumentPointer
(
ab_lengths
,
{
a_strides
},
{
b_strides
},
input
,
output
,
PassThrough
{},
Square
{});
ab_lengths
,
{
a_strides
},
{
b_strides
},
input
,
output
,
PassThrough
{},
UnaryOp
{});
if
(
!
broadcastPermute
.
IsSupportedArgument
(
argument
.
get
()))
if
(
!
broadcastPermute
.
IsSupportedArgument
(
argument
.
get
()))
{
{
...
@@ -116,7 +116,7 @@ int main()
...
@@ -116,7 +116,7 @@ int main()
{
{
b_device_buf
.
FromDevice
(
b
.
mData
.
data
());
b_device_buf
.
FromDevice
(
b
.
mData
.
data
());
Tensor
<
BDataType
>
host_b
(
nhwc
);
Tensor
<
BDataType
>
host_b
(
nhwc
);
host_elementwise4D
(
host_b
,
a
,
PassThrough
{},
Square
{});
host_elementwise4D
(
host_b
,
a
,
PassThrough
{},
UnaryOp
{});
pass
&=
pass
&=
ck
::
utils
::
check_err
(
b
.
mData
,
host_b
.
mData
,
"Error: Incorrect results b"
,
1e-3
,
1e-3
);
ck
::
utils
::
check_err
(
b
.
mData
,
host_b
.
mData
,
"Error: Incorrect results b"
,
1e-3
,
1e-3
);
...
...
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
View file @
c6b98c98
...
@@ -278,8 +278,8 @@ struct UnarySquare
...
@@ -278,8 +278,8 @@ struct UnarySquare
template
<
typename
T
>
template
<
typename
T
>
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
{
static_assert
(
is_same_v
<
T
,
float
>
||
is_same_v
<
T
,
double
>
||
is_same_v
<
T
,
int32_t
>
||
static_assert
(
is_same_v
<
T
,
float
>
||
is_same_v
<
T
,
half_t
>
||
is_same_v
<
T
,
double
>
||
is_same_v
<
T
,
int8_t
>
is_same_v
<
T
,
int32_t
>
||
is_same_v
<
T
,
int8_t
>
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
||
is_same_v
<
T
,
int4_t
>
||
is_same_v
<
T
,
int4_t
>
#endif
#endif
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment