Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
98109c8b
Unverified
Commit
98109c8b
authored
Sep 05, 2023
by
Chao Liu
Committed by
GitHub
Sep 05, 2023
Browse files
add softmax example (#6)
parent
0e92deb7
Changes
10
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
408 additions
and
23 deletions
+408
-23
example/91_tile_program/CMakeLists.txt
example/91_tile_program/CMakeLists.txt
+1
-0
example/91_tile_program/reduce.cpp
example/91_tile_program/reduce.cpp
+2
-2
example/91_tile_program/softmax.cpp
example/91_tile_program/softmax.cpp
+122
-0
example/91_tile_program/softmax.hpp
example/91_tile_program/softmax.hpp
+228
-0
include/ck/utility/amd_warp_shuffle.hpp
include/ck/utility/amd_warp_shuffle.hpp
+17
-7
include/ck/utility/bit_cast.hpp
include/ck/utility/bit_cast.hpp
+19
-0
include/ck/utility/common_header.hpp
include/ck/utility/common_header.hpp
+1
-0
include/ck/utility/data_type.hpp
include/ck/utility/data_type.hpp
+2
-0
include/ck/utility/math.hpp
include/ck/utility/math.hpp
+16
-5
include/ck/utility/type.hpp
include/ck/utility/type.hpp
+0
-9
No files found.
example/91_tile_program/CMakeLists.txt
View file @
98109c8b
...
# Example executables for the 91_tile_program samples.
add_example_executable(example_gemm gemm.cpp)
add_example_executable(example_gemm_gemm gemm_gemm.cpp)
add_example_executable(example_reduce reduce.cpp)
add_example_executable(example_softmax softmax.cpp)
example/91_tile_program/reduce.cpp
View file @
98109c8b
...
@@ -55,8 +55,8 @@ int main(int argc, char* argv[])
...
@@ -55,8 +55,8 @@ int main(int argc, char* argv[])
std
::
array
<
ck
::
index_t
,
2
>
a_lengths
{
M
,
N
};
std
::
array
<
ck
::
index_t
,
2
>
a_lengths
{
M
,
N
};
std
::
array
<
ck
::
index_t
,
2
>
a_strides
{
N
,
1
};
std
::
array
<
ck
::
index_t
,
2
>
a_strides
{
N
,
1
};
std
::
array
<
ck
::
index_t
,
2
>
b_lengths
{
M
};
std
::
array
<
ck
::
index_t
,
1
>
b_lengths
{
M
};
std
::
array
<
ck
::
index_t
,
2
>
b_strides
{
1
};
std
::
array
<
ck
::
index_t
,
1
>
b_strides
{
1
};
// host verify
// host verify
Tensor
<
ADataType
>
a_host
(
a_lengths
,
a_strides
);
Tensor
<
ADataType
>
a_host
(
a_lengths
,
a_strides
);
...
...
example/91_tile_program/softmax.cpp
0 → 100644
View file @
98109c8b
#include <cstring>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/tensor/tensor_view.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "softmax.hpp"
// Host reference softmax: b(m, :) = softmax(a(m, :)) for every row m.
// Rows are processed in parallel via make_ParallelTensorFunctor; each row uses
// the numerically-stable three-pass form: row max, sum of exp(a - max), then
// normalization. AccDataType is the accumulation type for max and sum.
template <typename ADataType, typename AccDataType, typename BDataType>
void reference_softmax(const Tensor<ADataType>& a_m_n, Tensor<BDataType>& b_m_n)
{
    auto per_row = [&](auto m) {
        const int num_cols = a_m_n.mDesc.GetLengths()[1];

        // pass 1: running maximum of the row (subtracted later for stability)
        AccDataType row_max = ck::NumericLimits<ADataType>::Lowest();

        for(int col = 0; col < num_cols; ++col)
        {
            const ADataType v = a_m_n(m, col);
            if(row_max < v)
            {
                row_max = v;
            }
        }

        // pass 2: denominator = sum(exp(a - max))
        AccDataType exp_sum = 0;

        for(int col = 0; col < num_cols; ++col)
        {
            const ADataType v = a_m_n(m, col);
            exp_sum += ck::math::exp(v - row_max);
        }

        // pass 3: normalized output
        for(int col = 0; col < num_cols; ++col)
        {
            const ADataType v = a_m_n(m, col);
            b_m_n(m, col)     = ck::math::exp(v - row_max) / exp_sum;
        }
    };

    make_ParallelTensorFunctor(per_row, b_m_n.mDesc.GetLengths()[0])(
        std::thread::hardware_concurrency());
}
int
main
(
int
argc
,
char
*
argv
[])
{
using
ADataType
=
float
;
using
AccDataType
=
float
;
using
BDataType
=
float
;
ck
::
index_t
M
=
3328
;
ck
::
index_t
N
=
4096
;
if
(
argc
==
3
)
{
M
=
std
::
stoi
(
argv
[
1
]);
N
=
std
::
stoi
(
argv
[
2
]);
}
std
::
array
<
ck
::
index_t
,
2
>
a_lengths
{
M
,
N
};
std
::
array
<
ck
::
index_t
,
2
>
a_strides
{
N
,
1
};
std
::
array
<
ck
::
index_t
,
2
>
b_lengths
{
M
,
N
};
std
::
array
<
ck
::
index_t
,
2
>
b_strides
{
N
,
1
};
// host verify
Tensor
<
ADataType
>
a_host
(
a_lengths
,
a_strides
);
Tensor
<
BDataType
>
b_host_ref
(
b_lengths
,
b_strides
);
Tensor
<
BDataType
>
b_host_dev
(
b_lengths
,
b_strides
);
ck
::
utils
::
FillUniformDistributionIntegerValue
<
ADataType
>
{
-
5.
f
,
5.
f
}(
a_host
);
// reference
reference_softmax
<
ADataType
,
AccDataType
,
BDataType
>
(
a_host
,
b_host_ref
);
DeviceMem
a_buf
(
sizeof
(
ADataType
)
*
a_host
.
GetElementSpaceSize
());
DeviceMem
b_buf
(
sizeof
(
BDataType
)
*
b_host_ref
.
GetElementSpaceSize
());
a_buf
.
ToDevice
(
a_host
.
mData
.
data
());
constexpr
ck
::
index_t
kMPerBlock
=
128
;
constexpr
ck
::
index_t
kNPerBlock
=
128
;
constexpr
ck
::
index_t
kBlockSize
=
256
;
ck
::
index_t
kGridSize
=
(
M
/
kMPerBlock
);
std
::
cout
<<
"grid size "
<<
kGridSize
<<
std
::
endl
;
const
auto
kernel
=
Softmax
<
ADataType
,
AccDataType
,
BDataType
,
kBlockSize
,
kMPerBlock
,
kNPerBlock
>
{};
float
ave_time
=
launch
(
ProgramServer
{},
kernel
,
kGridSize
,
kBlockSize
,
static_cast
<
ADataType
*>
(
a_buf
.
GetDeviceBuffer
()),
static_cast
<
BDataType
*>
(
b_buf
.
GetDeviceBuffer
()),
M
,
N
);
b_buf
.
FromDevice
(
b_host_dev
.
mData
.
data
());
std
::
size_t
num_btype
=
sizeof
(
ADataType
)
*
M
*
N
+
sizeof
(
BDataType
)
*
M
*
N
;
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
gb_per_sec
<<
" GB/s"
<<
std
::
endl
;
return
!
ck
::
utils
::
check_err
(
b_host_dev
,
b_host_ref
);
}
example/91_tile_program/softmax.hpp
0 → 100644
View file @
98109c8b
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "tile_program.hpp"
#include "ck/tile_program/tile/tile_distribution.hpp"
#include "ck/tile_program/tile/tile_window.hpp"
#include "ck/tile_program/tile/load_tile.hpp"
#include "ck/tile_program/tile/store_tile.hpp"
#include "ck/tile_program/tile/tile_elementwise.hpp"
#include "ck/tile_program/block_tile/block_reduce.hpp"
// Block-tiled softmax kernel functor: computes B = softmax(A) row-wise over the
// N dimension of an M x N row-major matrix. Each workgroup owns kMPerBlock rows
// and sweeps the N dimension in kNPerBlock-wide windows, in three passes:
//   1) per-row running max, 2) sum of exp(a - max), 3) normalized store to B.
// NOTE(review): the do-while sweeps and the {0, -N} window rewinds assume N is a
// multiple of kNPerBlock, and the grid covers M / kMPerBlock row blocks — confirm
// against the launcher.
template <typename ADataType,
          typename AccDataType,
          typename BDataType,
          ck::index_t kBlockSize,
          ck::index_t kMPerBlock,
          ck::index_t kNPerBlock>
struct Softmax
{
// Alternative static tile distributions for the A block window; only the
// #elif 1 branch is active. The encodings differ in how the kMPerBlock x
// kNPerBlock tile is mapped onto waves/lanes (see the wave-layout comments).
#if 0
    __host__ __device__ static constexpr auto MakeABlockTileDistribution()
    {
        using namespace ck;
        using namespace ck::tile_program;

        // 2x2 wave
        return make_static_tile_distribution(
            StaticTileDistributionEncoding<Sequence<>,
                                           Tuple<Sequence<2, 2, 4, 2, 4>, Sequence<2, 2, 32>>,
                                           Tuple<Sequence<1, 2>, Sequence<1, 2>>,
                                           Tuple<Sequence<1, 1>, Sequence<3, 2>>,
                                           Sequence<1, 2, 1, 1>,
                                           Sequence<0, 0, 2, 4>>{});
    }
#elif 0
    __host__ __device__ static constexpr auto MakeABlockTileDistribution()
    {
        using namespace ck;
        using namespace ck::tile_program;

        // 2x2 wave
        return make_static_tile_distribution(
            StaticTileDistributionEncoding<Sequence<>,
                                           Tuple<Sequence<2, 2, 32>, Sequence<2, 2, 4, 2, 4>>,
                                           Tuple<Sequence<2, 1>, Sequence<2, 1>>,
                                           Tuple<Sequence<1, 1>, Sequence<3, 2>>,
                                           Sequence<2, 1, 2, 2>,
                                           Sequence<0, 0, 2, 4>>{});
    }
#elif 1
    __host__ __device__ static constexpr auto MakeABlockTileDistribution()
    {
        using namespace ck;
        using namespace ck::tile_program;

        // 4x1 wave
        return make_static_tile_distribution(
            StaticTileDistributionEncoding<Sequence<>,
                                           Tuple<Sequence<1, 4, 4, 2, 4>, Sequence<4, 1, 32>>,
                                           Tuple<Sequence<1, 2>, Sequence<1, 2>>,
                                           Tuple<Sequence<1, 1>, Sequence<3, 2>>,
                                           Sequence<1, 2, 1, 1>,
                                           Sequence<0, 0, 2, 4>>{});
    }
#endif

    // Kernel body. p_a / p_b are global-memory pointers to the M x N input and
    // output matrices (row-major, leading stride N, as the tensor views below show).
    __host__ __device__ void operator()(ProgramServer& ps,
                                        const ADataType* p_a,
                                        BDataType* p_b,
                                        ck::index_t M,
                                        ck::index_t N) const
    {
        using namespace ck;
        using namespace ck::tile_program;
        using namespace ck::tile_program::block;

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        // global view of A as an M x N row-major tensor
        const auto a_m_n = make_naive_tensor_view<AddressSpaceEnum::Global>(
            p_a, make_tuple(M, N), make_tuple(N, 1), Number<32>{}, Number<1>{});

        // this workgroup's starting row
        const auto iM = ps.get_block_id() * kMPerBlock;

        // A window
        auto a_block_window =
            make_tile_window(a_m_n,
                             make_tuple(Number<kMPerBlock>{}, Number<kNPerBlock>{}),
                             {iM, 0},
                             MakeABlockTileDistribution());

        // reduce along the N (column) dimension
        constexpr auto reduce_dims = Sequence<1>{};

        const auto f_max = [](auto v0, auto v1) { return max(v0, v1); };

        const ADataType max_reduce_init_value = NumericLimits<ADataType>::Lowest();

        // max = max(a): per-row maxima, used for the numerically-stable exp
        auto max_block_tensor = decltype(block_tile_reduce<AccDataType>(
            load_tile(a_block_window), reduce_dims, f_max, max_reduce_init_value)){};

        tile_elementwise_inout(
            [&](auto& max) { max = type_convert<AccDataType>(max_reduce_init_value); },
            max_block_tensor);

        index_t iN = 0;

        // pass 1: sweep the N dimension accumulating the per-row max
        do
        {
            const auto a_block_tensor = load_tile(a_block_window);

            block_tile_reduce(max_block_tensor, a_block_tensor, reduce_dims, f_max);

            move_tile_window(a_block_window, {0, kNPerBlock});
            iN += kNPerBlock;
        } while(iN < N);

        // cross lane reduce: max
        block_tile_reduce_sync(max_block_tensor, f_max);

        // exp_sum = sum(exp(a - a_max))
        auto exp_sum_block_tensor =
            make_static_distributed_tensor<AccDataType>(max_block_tensor.GetTileDistribution());

        tile_elementwise_inout([&](auto& exp_sum) { exp_sum = 0; }, exp_sum_block_tensor);

        // reset window location (assumes the loop above advanced by exactly N)
        iN = 0;
        move_tile_window(a_block_window, {0, -N});

        // pass 2: sweep again accumulating sum(exp(a - max)) per row
        do
        {
            const auto a_block_tensor = load_tile(a_block_window);

            constexpr auto a_spans = decltype(a_block_tensor)::GetDistributedSpans();

            //
            sweep_tile_span(a_spans[I0], [&](auto idx0) {
                constexpr auto m_idx = make_tuple(idx0);

                const auto v_max = max_block_tensor.GetElementFromTileDistributedIndices(m_idx);

                AccDataType v_exp_sum =
                    exp_sum_block_tensor.GetElementFromTileDistributedIndices(m_idx);

                sweep_tile_span(a_spans[I1], [&](auto idx1) {
                    constexpr auto m_n_idx = make_tuple(idx0, idx1);

                    const auto v_a = a_block_tensor.GetElementFromTileDistributedIndices(m_n_idx);

                    // presumably silences an unused-capture warning in some
                    // instantiations; v_max IS used just below — TODO confirm
                    (void)v_max;

                    // exp and sum
                    v_exp_sum += math::exp(v_a - v_max);
                });

                exp_sum_block_tensor.SetElementFromTileDistributedIndices(m_idx, v_exp_sum);
            });

            move_tile_window(a_block_window, {0, kNPerBlock});
            iN += kNPerBlock;
        } while(iN < N);

        // cross lane reduce: sum
        block_tile_reduce_sync(exp_sum_block_tensor, [](auto v0, auto v1) { return v0 + v1; });

        // B: global view of the output, same layout as A
        const auto b_m_n = make_naive_tensor_view<AddressSpaceEnum::Global>(
            p_b, make_tuple(M, N), make_tuple(N, 1), Number<32>{}, Number<1>{});

        // B window
        auto b_block_window = make_tile_window(
            b_m_n, make_tuple(Number<kMPerBlock>{}, Number<kNPerBlock>{}), {iM, 0});

        // reset window location
        iN = 0;
        move_tile_window(a_block_window, {0, -N});

        // pass 3: recompute exp(a - max), normalize by exp_sum, store to B
        do
        {
            const auto a_block_tensor = load_tile(a_block_window);

            constexpr auto a_spans = decltype(a_block_tensor)::GetDistributedSpans();

            auto b_block_tensor =
                make_static_distributed_tensor<BDataType>(a_block_tensor.GetTileDistribution());

            //
            sweep_tile_span(a_spans[I0], [&](auto idx0) {
                constexpr auto m_idx = make_tuple(idx0);

                const auto v_max = max_block_tensor.GetElementFromTileDistributedIndices(m_idx);
                const auto v_exp_sum =
                    exp_sum_block_tensor.GetElementFromTileDistributedIndices(m_idx);

                sweep_tile_span(a_spans[I1], [&](auto idx1) {
                    constexpr auto m_n_idx = make_tuple(idx0, idx1);

                    const auto v_a = a_block_tensor.GetElementFromTileDistributedIndices(m_n_idx);

                    // exp
                    const BDataType v_b =
                        type_convert<BDataType>(math::exp(v_a - v_max) / v_exp_sum);

                    b_block_tensor.SetElementFromTileDistributedIndices(m_n_idx, v_b);
                });
            });

            // store B tile
            store_tile(b_block_window, b_block_tensor);

            move_tile_window(a_block_window, {0, kNPerBlock});
            move_tile_window(b_block_window, {0, kNPerBlock});
            iN += kNPerBlock;
        } while(iN < N);
    }
};
include/ck/utility/amd_warp_shuffle.hpp
View file @
98109c8b
...
@@ -8,24 +8,34 @@
...
@@ -8,24 +8,34 @@
namespace
ck
{
namespace
ck
{
// Fetch the value held by the lane `lane_delta` positions below the caller in
// the wavefront. Implemented with ds_bpermute (the __shfl_up path is kept for
// reference under #if 0); the "wrap_around" delta means the source lane index
// wraps within the wavefront rather than clamping. 32-bit payloads only.
template <typename T>
__device__ T warp_shuffle_up(const T& v_local, uint32_t lane_delta)
{
#if 0
    return __shfl_up(v_local, lane_delta);
#elif 1
    static_assert(sizeof(T) == sizeof(int32_t), "wrong!");

    // shifting up by d is a wrap-around rotation by (warpSize - d)
    const uint32_t backward_delta = warpSize - lane_delta;

    // ds_bpermute addresses lanes in bytes, hence the << 2
    const int32_t remote_bits = __builtin_amdgcn_ds_bpermute(
        (__lane_id() + backward_delta) << 2, bit_cast<int32_t>(v_local));

    return bit_cast<T>(remote_bits);
#endif
}
// Fetch the value held by the lane `lane_delta` positions above the caller in
// the wavefront, via ds_bpermute (the __shfl_down path is kept for reference
// under #if 0). 32-bit payloads only, enforced by the static_assert.
template <typename T>
__device__ T warp_shuffle_down(const T& v_local, uint32_t lane_delta)
{
#if 0
    return __shfl_down(v_local, lane_delta);
#elif 1
    static_assert(sizeof(T) == sizeof(int32_t), "wrong!");

    // ds_bpermute addresses lanes in bytes, hence the << 2
    const int32_t remote_bits = __builtin_amdgcn_ds_bpermute(
        (__lane_id() + lane_delta) << 2, bit_cast<int32_t>(v_local));

    return bit_cast<T>(remote_bits);
#endif
}
...
...
include/ck/utility/bit_cast.hpp
0 → 100644
View file @
98109c8b
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
#include "ck/utility/enable_if.hpp"
namespace
ck
{
// Reinterpret the object representation of x as a value of type Y, analogous
// to C++20 std::bit_cast but usable in both host and device code through the
// compiler builtin. SFINAE-disabled unless sizeof(X) == sizeof(Y).
template <typename Y,
          typename X,
          typename enable_if<sizeof(X) == sizeof(Y), bool>::type = false>
__host__ __device__ constexpr Y bit_cast(const X& x)
{
    // requires a compiler that provides __builtin_bit_cast
    static_assert(__has_builtin(__builtin_bit_cast), "");

    return __builtin_bit_cast(Y, x);
}
}
// namespace ck
include/ck/utility/common_header.hpp
View file @
98109c8b
...
@@ -7,6 +7,7 @@
...
@@ -7,6 +7,7 @@
#include "ck/utility/static_assert.hpp"
#include "ck/utility/static_assert.hpp"
#include "ck/utility/remove_cvref.hpp"
#include "ck/utility/remove_cvref.hpp"
#include "ck/utility/is_static.hpp"
#include "ck/utility/is_static.hpp"
#include "ck/utility/bit_cast.hpp"
#include "ck/utility/print.hpp"
#include "ck/utility/print.hpp"
#include "ck/utility/array.hpp"
#include "ck/utility/array.hpp"
#include "ck/utility/container_helper.hpp"
#include "ck/utility/container_helper.hpp"
...
...
include/ck/utility/data_type.hpp
View file @
98109c8b
...
@@ -3,6 +3,8 @@
...
@@ -3,6 +3,8 @@
#pragma once
#pragma once
#include "ck/ck.hpp"
#include "ck/utility/bit_cast.hpp"
#include "ck/utility/statically_indexed_array.hpp"
#include "ck/utility/statically_indexed_array.hpp"
namespace
ck
{
namespace
ck
{
...
...
include/ck/utility/math.hpp
View file @
98109c8b
...
@@ -150,27 +150,38 @@ __host__ __device__ constexpr T clamp(const T& x, const T& lowerbound, const T&
...
@@ -150,27 +150,38 @@ __host__ __device__ constexpr T clamp(const T& x, const T& lowerbound, const T&
return
min
(
max
(
x
,
lowerbound
),
upperbound
);
return
min
(
max
(
x
,
lowerbound
),
upperbound
);
}
}
// prevent implicit type casting: only the exact-type specializations below are
// defined, so calling ck::math::exp with any other type fails to link instead
// of silently promoting (e.g. float -> double).
template <typename T>
__host__ T exp(T x);

template <typename T>
__device__ T exp(T x);

// TODO: add f16 support using v_exp_f16
template <>
inline __device__ float exp<float>(float x)
{
    return __expf(x);
}

template <>
inline __device__ double exp<double>(double x)
{
    // BUGFIX: must qualify the call. Unqualified exp(x) is looked up inside
    // ck::math first (and ADL adds nothing for double), so it resolves back to
    // this very specialization -- infinite recursion. ::exp picks up the
    // device math library's double-precision exp.
    return ::exp(x);
}

template <>
inline __host__ float exp<float>(float x)
{
    return std::expf(x);
}

template <>
inline __host__ double exp<double>(double x)
{
    return std::exp(x);
}
// greatest common divisor, aka highest common factor
// greatest common divisor, aka highest common factor
__host__
__device__
constexpr
index_t
gcd
(
index_t
x
,
index_t
y
)
__host__
__device__
constexpr
index_t
gcd
(
index_t
x
,
index_t
y
)
...
...
include/ck/utility/type.hpp
View file @
98109c8b
...
@@ -32,13 +32,4 @@ inline constexpr bool is_pointer_v = std::is_pointer<T>::value;
...
@@ -32,13 +32,4 @@ inline constexpr bool is_pointer_v = std::is_pointer<T>::value;
// Variable-template shorthand for std::is_empty (mirrors std::is_empty_v).
template <typename T>
inline constexpr bool is_empty_v = std::is_empty<T>::value;
// bit_cast
template
<
typename
Y
,
typename
X
,
typename
enable_if
<
sizeof
(
X
)
==
sizeof
(
Y
),
bool
>
::
type
=
false
>
__host__
__device__
constexpr
Y
bit_cast
(
const
X
&
x
)
{
static_assert
(
__has_builtin
(
__builtin_bit_cast
),
""
);
return
__builtin_bit_cast
(
Y
,
x
);
}
}
// namespace ck
}
// namespace ck
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment