Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
667047b9
Commit
667047b9
authored
Sep 06, 2024
by
carlushuang
Browse files
topk-softmax
parent
840cba8e
Changes
25
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
896 additions
and
125 deletions
+896
-125
example/ck_tile/05_moe/fused_moe_interface.cpp
example/ck_tile/05_moe/fused_moe_interface.cpp
+0
-0
example/ck_tile/05_moe/fused_moe_interface.hpp
example/ck_tile/05_moe/fused_moe_interface.hpp
+22
-0
example/ck_tile/05_moe/topk_softmax_api.cpp
example/ck_tile/05_moe/topk_softmax_api.cpp
+50
-0
example/ck_tile/05_moe/topk_softmax_api.hpp
example/ck_tile/05_moe/topk_softmax_api.hpp
+21
-0
include/ck_tile/core/arch/utility.hpp
include/ck_tile/core/arch/utility.hpp
+43
-0
include/ck_tile/core/numeric/math.hpp
include/ck_tile/core/numeric/math.hpp
+2
-1
include/ck_tile/host/reference/reference_softmax.hpp
include/ck_tile/host/reference/reference_softmax.hpp
+58
-20
include/ck_tile/host/reference/reference_topk.hpp
include/ck_tile/host/reference/reference_topk.hpp
+21
-0
include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp
.../ck_tile/ops/elementwise/unary_element_wise_operation.hpp
+105
-103
include/ck_tile/ops/reduce/block/block_reduce.hpp
include/ck_tile/ops/reduce/block/block_reduce.hpp
+89
-1
include/ck_tile/ops/softmax.hpp
include/ck_tile/ops/softmax.hpp
+8
-0
include/ck_tile/ops/softmax/block/block_softmax_2d.hpp
include/ck_tile/ops/softmax/block/block_softmax_2d.hpp
+80
-0
include/ck_tile/ops/softmax/block/block_softmax_2d_problem.hpp
...de/ck_tile/ops/softmax/block/block_softmax_2d_problem.hpp
+16
-0
include/ck_tile/ops/topk.hpp
include/ck_tile/ops/topk.hpp
+8
-0
include/ck_tile/ops/topk/block/block_topk_stream_2d.hpp
include/ck_tile/ops/topk/block/block_topk_stream_2d.hpp
+112
-0
include/ck_tile/ops/topk/block/block_topk_stream_2d_problem.hpp
...e/ck_tile/ops/topk/block/block_topk_stream_2d_problem.hpp
+22
-0
include/ck_tile/ops/topk_softmax.hpp
include/ck_tile/ops/topk_softmax.hpp
+10
-0
include/ck_tile/ops/topk_softmax/kernel/topk_softmax_kernel.hpp
...e/ck_tile/ops/topk_softmax/kernel/topk_softmax_kernel.hpp
+117
-0
include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_pipeline.hpp
...k_softmax/pipeline/topk_softmax_warp_per_row_pipeline.hpp
+53
-0
include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_policy.hpp
...opk_softmax/pipeline/topk_softmax_warp_per_row_policy.hpp
+59
-0
No files found.
example/ck_tile/05_moe/fused_moe_interface.cpp
0 → 100644
View file @
667047b9
example/ck_tile/05_moe/fused_moe_interface.hpp
0 → 100644
View file @
667047b9
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
struct
fused_moe_traits
{
int
hdim_q
;
int
hdim_v
;
std
::
string
data_type
;
bool
is_group_mode
;
bool
is_v_rowmajor
;
mask_enum
mask_type
;
bias_enum
bias_type
;
// 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
bool
has_lse
;
bool
has_dropout
;
bool
do_fp8_static_quant
;
// TODO: padding check is inside this api
};
float
fused_moe
(
fused_moe_traits
,
fused_moe_args
,
const
ck_tile
::
stream_config
&
);
example/ck_tile/05_moe/topk_softmax_api.cpp
0 → 100644
View file @
667047b9
#include "topk_softmax_api.hpp"
float
topk_softmax
(
topk_softmax_trait
t
,
topk_softmax_kargs
a
,
ck_tile
::
stream_config
s
)
{
if
(
t
.
input_type
==
"fp16"
&&
t
.
weight_type
==
"fp32"
)
{
using
ts_input_type
=
ck_tile
::
fp16_t
;
using
ts_weight_type
=
float
;
using
ts_index_type
=
ck_tile
::
index_t
;
constexpr
ck_tile
::
index_t
ts_experts
=
8
;
using
ts_problem
=
ck_tile
::
TopkSoftmaxWarpPerRowProblem
<
ts_input_type
,
ts_weight_type
,
ts_index_type
,
ts_experts
>
;
using
ts_pipeline
=
ck_tile
::
TopkSoftmaxWarpPerRowPipeline
<
ts_problem
>
;
using
kernel
=
ck_tile
::
TopkSoftmaxKernel
<
ts_pipeline
>
;
auto
kargs
=
kernel
::
MakeKargs
(
a
);
const
dim3
grids
=
kernel
::
GridSize
(
a
);
constexpr
dim3
blocks
=
kernel
::
BlockSize
();
float
ave_time
=
ck_tile
::
launch_kernel
(
s
,
ck_tile
::
make_kernel
<
blocks
.
x
,
1
>
(
kernel
{},
grids
,
blocks
,
0
,
kargs
));
return
ave_time
;
}
else
if
(
t
.
input_type
==
"bf16"
&&
t
.
weight_type
==
"fp32"
)
{
using
ts_input_type
=
ck_tile
::
bf16_t
;
using
ts_weight_type
=
float
;
using
ts_index_type
=
ck_tile
::
index_t
;
constexpr
ck_tile
::
index_t
ts_experts
=
8
;
using
ts_problem
=
ck_tile
::
TopkSoftmaxWarpPerRowProblem
<
ts_input_type
,
ts_weight_type
,
ts_index_type
,
ts_experts
>
;
using
ts_pipeline
=
ck_tile
::
TopkSoftmaxWarpPerRowPipeline
<
ts_problem
>
;
using
kernel
=
ck_tile
::
TopkSoftmaxKernel
<
ts_pipeline
>
;
auto
kargs
=
kernel
::
MakeKargs
(
a
);
const
dim3
grids
=
kernel
::
GridSize
(
a
);
constexpr
dim3
blocks
=
kernel
::
BlockSize
();
float
ave_time
=
ck_tile
::
launch_kernel
(
s
,
ck_tile
::
make_kernel
<
blocks
.
x
,
1
>
(
kernel
{},
grids
,
blocks
,
0
,
kargs
));
return
ave_time
;
}
return
-
1
;
}
example/ck_tile/05_moe/topk_softmax_api.hpp
0 → 100644
View file @
667047b9
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/ops/topk_softmax.hpp"
#include <string>
struct
topk_softmax_trait
{
std
::
string
input_type
;
std
::
string
weight_type
;
// currently always float
int
experts
;
};
struct
topk_softmax_kargs
:
public
ck_tile
::
TopkSoftmaxHostArgs
{
};
float
topk_softmax
(
topk_softmax_trait
t
,
topk_softmax_kargs
a
,
ck_tile
::
stream_config
s
);
include/ck_tile/core/arch/utility.hpp
View file @
667047b9
...
...
@@ -120,4 +120,47 @@ CK_TILE_DEVICE T warp_shuffle_down(const T& v_local, uint32_t lane_delta)
#endif
}
template
<
typename
T
>
CK_TILE_DEVICE
T
warp_shuffle
(
const
T
&
v_local
,
uint32_t
src_lane
)
{
#if 0
return __shfl(v_local, src_lane);
#elif
1
if
constexpr
(
sizeof
(
int32_t
)
>
sizeof
(
T
))
{
union
packet
{
int32_t
x
;
T
v
;
};
packet
p
;
p
.
v
=
v_local
;
packet
p_remote
;
p_remote
.
x
=
__builtin_amdgcn_ds_bpermute
(
src_lane
<<
2
,
bit_cast
<
int32_t
>
(
p
));
return
p_remote
.
v
;
}
else
if
constexpr
(
sizeof
(
int32_t
)
==
sizeof
(
T
))
{
const
int32_t
v_remote_tmp
=
__builtin_amdgcn_ds_bpermute
(
src_lane
<<
2
,
bit_cast
<
int32_t
>
(
v_local
));
return
bit_cast
<
T
>
(
v_remote_tmp
);
}
else
{
static_assert
(
sizeof
(
T
)
%
sizeof
(
int32_t
)
==
0
,
"wrong!"
);
constexpr
index_t
elm
=
sizeof
(
T
)
/
sizeof
(
int32_t
);
using
vector_type
=
thread_buffer
<
int32_t
,
elm
>
;
auto
vs
=
bit_cast
<
vector_type
>
(
v_local
);
auto
vs_remote
=
vector_type
{};
static_for
<
0
,
elm
,
1
>
{}([
&
](
auto
i_e
)
{
int32_t
tmp
=
__builtin_amdgcn_ds_bpermute
(
src_lane
<<
2
,
bit_cast
<
int32_t
>
(
vs
[
i_e
]));
vs_remote
(
i_e
)
=
tmp
;
});
return
bit_cast
<
T
>
(
vs_remote
);
}
#endif
}
}
// namespace ck_tile
include/ck_tile/core/numeric/math.hpp
View file @
667047b9
...
...
@@ -1406,7 +1406,8 @@ CK_TILE_DEVICE T rcp(T x)
#if !CK_TILE_WORKAROUND_SWDEV_383542
return
__frcp_rn
(
x
);
#else
return
__ocml_native_recip_f32
(
x
);
// return __ocml_native_recip_f32(x);
return
__builtin_amdgcn_rcpf
(
x
);
#endif
};
...
...
include/ck_tile/host/reference/reference_softmax.hpp
View file @
667047b9
...
...
@@ -9,43 +9,81 @@
namespace
ck_tile
{
template
<
typename
AData
Type
,
typename
AccData
Type
,
typename
BData
Type
>
CK_TILE_HOST
void
reference_softmax
(
const
HostTensor
<
ADataType
>&
a_m_n
,
HostTensor
<
BData
Type
>&
b_m_n
)
template
<
typename
Input
Type
,
typename
Compute
Type
,
typename
OutputType
=
Compute
Type
>
CK_TILE_HOST
void
reference_softmax
(
const
HostTensor
<
InputType
>&
x
,
HostTensor
<
Output
Type
>&
y
,
index_t
dim
=
-
1
)
{
auto
f
=
[
&
](
auto
m
)
{
const
int
N
=
a_m_n
.
mDesc
.
get_lengths
()[
1
];
index_t
rank
=
x
.
get_num_of_dimension
();
assert
(
rank
==
y
.
get_num_of_dimension
());
assert
(
dim
==
-
1
||
dim
<
rank
);
AccDataType
v_max
=
ck_tile
::
numeric
<
ADataType
>::
Lowest
();
index_t
target_dim
=
dim
==
-
1
?
(
rank
-
1
)
:
dim
;
index_t
softmax_len
=
x
.
get_length
(
target_dim
);
index_t
n_parallel
=
x
.
get_element_size
()
/
softmax_len
;
auto
x_len
=
x
.
get_lengths
();
// max
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
const
ADataType
v_a
=
a_m_n
(
m
,
n
);
auto
f
=
[
&
](
auto
i_element
)
{
std
::
vector
<
size_t
>
coord
=
[
&
]()
{
std
::
vector
<
size_t
>
t_
(
rank
,
0
);
size_t
r
=
i_element
;
for
(
index_t
i
=
rank
-
1
;
i
>=
0
;
i
--
)
{
if
(
i
==
target_dim
)
continue
;
t_
[
i
]
=
r
%
x_len
[
i
];
r
=
r
/
x_len
[
i
];
}
return
t_
;
}();
ComputeType
v_max
=
-
ck_tile
::
numeric
<
ComputeType
>::
infinity
();
v_max
=
v_max
<
v_a
?
v_a
:
v_max
;
// compute max
for
(
auto
idx
=
0
;
idx
<
softmax_len
;
idx
++
)
{
auto
c_
=
coord
;
c_
[
target_dim
]
=
idx
;
const
ComputeType
v_x
=
ck_tile
::
type_convert
<
ComputeType
>
(
x
(
c_
));
v_max
=
v_max
<
v_x
?
v_x
:
v_max
;
}
AccData
Type
v_exp_sum
=
0
;
Compute
Type
v_exp_sum
=
static_cast
<
ComputeType
>
(
0
)
;
// sum
for
(
int
n
=
0
;
n
<
N
;
++
n
)
for
(
auto
idx
=
0
;
idx
<
softmax_len
;
idx
++
)
{
const
ADataType
v_a
=
a_m_n
(
m
,
n
);
auto
c_
=
coord
;
c_
[
target_dim
]
=
idx
;
v_exp_sum
+=
ck_tile
::
exp
(
v_a
-
v_max
);
const
ComputeType
v_x
=
ck_tile
::
type_convert
<
ComputeType
>
(
x
(
c_
));
v_exp_sum
+=
ck_tile
::
exp
(
v_x
-
v_max
);
}
// elementwise
for
(
int
n
=
0
;
n
<
N
;
++
n
)
for
(
auto
idx
=
0
;
idx
<
softmax_len
;
idx
++
)
{
const
ADataType
v_a
=
a_m_n
(
m
,
n
);
auto
c_
=
coord
;
c_
[
target_dim
]
=
idx
;
const
ComputeType
v_x
=
ck_tile
::
type_convert
<
ComputeType
>
(
x
(
c_
));
auto
out
=
ck_tile
::
exp
(
v_x
-
v_max
)
/
v_exp_sum
;
b_m_n
(
m
,
n
)
=
ck_tile
::
exp
(
v_a
-
v_max
)
/
v_exp_sum
;
y
(
c_
)
=
ck_tile
::
type_convert
<
OutputType
>
(
out
)
;
}
};
make_ParallelTensorFunctor
(
f
,
b_m_n
.
mDesc
.
get_lengths
()[
0
])(
std
::
thread
::
hardware_concurrency
());
make_ParallelTensorFunctor
(
f
,
n_parallel
)(
std
::
thread
::
hardware_concurrency
());
}
template
<
typename
InputType
,
typename
ComputeType
,
typename
OutputType
=
ComputeType
>
CK_TILE_HOST
auto
reference_softmax
(
const
HostTensor
<
InputType
>&
x
,
index_t
dim
=
-
1
)
{
HostTensor
<
OutputType
>
y
(
x
.
get_lengths
(),
x
.
get_strides
());
reference_softmax
<
InputType
,
ComputeType
,
OutputType
>
(
x
,
y
,
dim
);
return
y
;
}
}
// namespace ck_tile
include/ck_tile/host/reference/reference_topk.hpp
View file @
667047b9
...
...
@@ -100,4 +100,25 @@ CK_TILE_HOST void reference_topk(const HostTensor<DataType>& x,
make_ParallelTensorFunctor
(
f
,
n_parallel
)(
std
::
thread
::
hardware_concurrency
());
}
// TODO: if using this method, the return tensor would be dense(no stride)
template
<
typename
DataType
,
typename
IndexType
=
index_t
>
CK_TILE_HOST
auto
reference_topk
(
const
HostTensor
<
DataType
>&
x
,
index_t
k
,
index_t
dim
=
-
1
,
bool
largest
=
true
,
bool
sorted
=
true
)
{
auto
lens
=
x
.
get_lengths
();
index_t
target_dim
=
(
dim
==
-
1
)
?
(
lens
.
size
()
-
1
)
:
dim
;
assert
(
target_dim
<
lens
.
size
());
assert
(
k
<=
lens
[
target_dim
]);
lens
[
target_dim
]
=
k
;
HostTensor
<
DataType
>
y_values
(
lens
);
HostTensor
<
IndexType
>
y_indices
(
lens
);
reference_topk
<
DataType
,
IndexType
>
(
x
,
y_values
,
y_indices
,
k
,
dim
,
largest
,
sorted
);
return
ck_tile
::
make_tuple
(
y_values
,
y_indices
);
}
}
// namespace ck_tile
include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp
View file @
667047b9
...
...
@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include <type_traits>
namespace
ck_tile
{
namespace
element_wise
{
...
...
@@ -258,10 +259,10 @@ struct ConvertBF16RTN
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
{
// check Y datatype
static_assert(
ck_tile
::is_same<Y, ck_tile::bf16_t>
::value
, "Data type is not supported by this operation!");
static_assert(
std
::is_same
_v
<Y, ck_tile::bf16_t>, "Data type is not supported by this operation!");
// check X datatype
static_assert(
ck_tile
::is_same<X, float>
::value || ck_tile
::is_same<X, ck_tile::fp16_t>
::value
,
static_assert(
std
::is_same
_v
<X, float>
|| std
::is_same
_v
<X, ck_tile::fp16_t>,
"Data type is not supported by this operation!");
y = bf16_convert_rtn<Y>(x);
...
...
@@ -275,11 +276,11 @@ struct ConvertF8SR
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
{
// check Y datatype
static_assert(
ck_tile
::is_same<Y, ck_tile::fp8_t>
::value || ck_tile
::is_same<Y, ck_tile::bf8_t>
::value
,
static_assert(
std
::is_same
_v
<Y, ck_tile::fp8_t>
|| std
::is_same
_v
<Y, ck_tile::bf8_t>,
"Data type is not supported by this operation!");
// check X datatype
static_assert(
ck_tile
::is_same<X, float>
::value || ck_tile
::is_same<X, ck_tile::fp16_t>
::value
,
static_assert(
std
::is_same
_v
<X, float>
|| std
::is_same
_v
<X, ck_tile::fp16_t>,
"Data type is not supported by this operation!");
y = f8_convert_sr<Y>(x);
...
...
@@ -293,11 +294,11 @@ struct ConvertF8RNE
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
{
// check Y datatype
static_assert(
ck_tile
::is_same<Y, ck_tile::fp8_t>
::value || ck_tile
::is_same<Y, ck_tile::bf8_t>
::value
,
static_assert(
std
::is_same
_v
<Y, ck_tile::fp8_t>
|| std
::is_same
_v
<Y, ck_tile::bf8_t>,
"Data type is not supported by this operation!");
// check X datatype
static_assert(
ck_tile
::is_same<X, float>
::value || ck_tile
::is_same<X, ck_tile::fp16_t>
::value
,
static_assert(
std
::is_same
_v
<X, float>
|| std
::is_same
_v
<X, ck_tile::fp16_t>,
"Data type is not supported by this operation!");
y = f8_convert_rne<Y>(x);
...
...
@@ -362,7 +363,7 @@ struct ScaleAndResetNaNToMinusInfinity
template
<
>
CK_TILE_HOST_DEVICE
void
operator
()
<
float
,
float
>
(
float
&
y
,
const
float
&
x
)
const
{
y
=
ck_tile
::
isnan
(
x
)
?
-
ck_tile
::
NumericLimits
<
float
>::
I
nfinity
()
:
scale_
*
x
;
y
=
ck_tile
::
isnan
(
x
)
?
-
numeric
<
float
>::
i
nfinity
()
:
scale_
*
x
;
};
float
scale_
;
...
...
@@ -375,8 +376,8 @@ struct UnaryDivide
template
<
typename
T
>
CK_TILE_HOST_DEVICE
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
ck_tile
::
is_same
<
T
,
float
>
::
value
||
ck_tile
::
is_same
<
T
,
double
>
::
value
||
ck_tile
::
is_same
<
T
,
int32_t
>
::
value
,
static_assert
(
std
::
is_same
_v
<
T
,
float
>
||
std
::
is_same
_v
<
T
,
double
>
||
std
::
is_same
_v
<
T
,
int32_t
>
,
"Data type is not supported by this operation!"
);
y
=
x
/
type_convert
<
T
>
(
divider_
);
...
...
@@ -390,10 +391,11 @@ struct UnarySquare
template
<
typename
T
>
CK_TILE_HOST_DEVICE
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
is_same_v
<
T
,
float
>
||
is_same_v
<
T
,
ck_tile
::
fp16_t
>
||
is_same_v
<
T
,
double
>
||
is_same_v
<
T
,
int32_t
>
||
is_same_v
<
T
,
int8_t
>
static_assert
(
std
::
is_same_v
<
T
,
float
>
||
std
::
is_same_v
<
T
,
ck_tile
::
fp16_t
>
||
std
::
is_same_v
<
T
,
double
>
||
std
::
is_same_v
<
T
,
int32_t
>
||
std
::
is_same_v
<
T
,
int8_t
>
#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
||
is_same_v
<
T
,
int4_t
>
||
std
::
is_same_v
<
T
,
int4_t
>
#endif
,
"Data type is not supported by this operation!"
);
...
...
@@ -406,9 +408,9 @@ struct UnaryAbs
template
<
typename
T
>
CK_TILE_HOST_DEVICE
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
ck_tile
::
is_same
<
T
,
float
>
::
value
||
ck_tile
::
is_same
<
T
,
double
>
::
value
||
ck_tile
::
is_same
<
T
,
ck_tile
::
fp16_t
>
::
value
||
ck_tile
::
is_same
<
T
,
int32_t
>::
value
||
ck_tile
::
is_same
<
T
,
int8_t
>
::
value
,
static_assert
(
std
::
is_same
_v
<
T
,
float
>
||
std
::
is_same
_v
<
T
,
double
>
||
std
::
is_same
_v
<
T
,
ck_tile
::
fp16_t
>
||
std
::
is_same_v
<
T
,
int32_t
>
||
std
::
is_same
_v
<
T
,
int8_t
>
,
"Data type is not supported by this operation!"
);
y
=
ck_tile
::
abs
(
x
);
...
...
@@ -420,7 +422,7 @@ struct UnarySqrt
template
<
typename
T
>
CK_TILE_HOST_DEVICE
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
ck_tile
::
is_same
<
T
,
float
>
::
value
||
ck_tile
::
is_same
<
T
,
double
>
::
value
,
static_assert
(
std
::
is_same
_v
<
T
,
float
>
||
std
::
is_same
_v
<
T
,
double
>
,
"Data type is not supported by this operation!"
);
y
=
ck_tile
::
sqrt
(
x
);
...
...
@@ -432,9 +434,9 @@ struct Relu
template
<
typename
T
>
CK_TILE_HOST_DEVICE
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
ck_tile
::
is_same
<
T
,
float
>
::
value
||
ck_tile
::
is_same
<
T
,
double
>
::
value
||
ck_tile
::
is_same
<
T
,
ck_tile
::
fp16_t
>
::
value
||
ck_tile
::
is_same
<
T
,
int32_t
>::
value
||
ck_tile
::
is_same
<
T
,
int8_t
>
::
value
,
static_assert
(
std
::
is_same
_v
<
T
,
float
>
||
std
::
is_same
_v
<
T
,
double
>
||
std
::
is_same
_v
<
T
,
ck_tile
::
fp16_t
>
||
std
::
is_same_v
<
T
,
int32_t
>
||
std
::
is_same
_v
<
T
,
int8_t
>
,
"Data type is not supported by this operation!"
);
y
=
x
>
0
?
x
:
0
;
}
...
...
@@ -597,9 +599,9 @@ struct Sigmoid
template
<
typename
T
>
CK_TILE_HOST_DEVICE
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
ck_tile
::
is_same
<
T
,
float
>
::
value
||
ck_tile
::
is_same
<
T
,
double
>
::
value
||
ck_tile
::
is_same
<
T
,
ck_tile
::
fp16_t
>
::
value
||
ck_tile
::
is_same
<
T
,
int8_t
>::
value
||
ck_tile
::
is_same
<
T
,
int32_t
>
::
value
,
static_assert
(
std
::
is_same
_v
<
T
,
float
>
||
std
::
is_same
_v
<
T
,
double
>
||
std
::
is_same
_v
<
T
,
ck_tile
::
fp16_t
>
||
std
::
is_same_v
<
T
,
int8_t
>
||
std
::
is_same
_v
<
T
,
int32_t
>
,
"Data type is not supported by this operation!"
);
constexpr
T
one
=
type_convert
<
T
>
(
1
);
y
=
one
/
(
one
+
ck_tile
::
exp
(
-
x
));
...
...
@@ -611,9 +613,9 @@ struct Silu
template
<
typename
T
>
CK_TILE_HOST_DEVICE
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
is_same_v
<
T
,
float
>
||
is_same_v
<
T
,
double
>
||
is_same_v
<
T
,
ck_tile
::
fp16_t
>
||
is_same_v
<
T
,
int8_t
>
||
is_same_v
<
T
,
int32_t
>
,
static_assert
(
std
::
is_same_v
<
T
,
float
>
||
std
::
is_same_v
<
T
,
double
>
||
std
::
is_same_v
<
T
,
ck_tile
::
fp16_t
>
||
std
::
is_same_v
<
T
,
int8_t
>
||
std
::
is_same_v
<
T
,
int32_t
>
,
"Data type is not supported by this operation!"
);
constexpr
T
one
=
type_convert
<
T
>
(
1
);
y
=
x
*
(
one
/
(
one
+
ck_tile
::
exp
(
-
x
)));
...
...
@@ -625,9 +627,9 @@ struct TanH
template
<
typename
T
>
CK_TILE_HOST_DEVICE
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
ck_tile
::
is_same
<
T
,
float
>
::
value
||
ck_tile
::
is_same
<
T
,
double
>
::
value
||
ck_tile
::
is_same
<
T
,
ck_tile
::
fp16_t
>
::
value
||
ck_tile
::
is_same
<
T
,
int8_t
>::
value
||
ck_tile
::
is_same
<
T
,
int32_t
>
::
value
,
static_assert
(
std
::
is_same
_v
<
T
,
float
>
||
std
::
is_same
_v
<
T
,
double
>
||
std
::
is_same
_v
<
T
,
ck_tile
::
fp16_t
>
||
std
::
is_same_v
<
T
,
int8_t
>
||
std
::
is_same
_v
<
T
,
int32_t
>
,
"Data type is not supported by this operation!"
);
y
=
ck_tile
::
tanh
(
x
);
...
...
@@ -639,9 +641,9 @@ struct ACos
template
<
typename
T
>
CK_TILE_HOST_DEVICE
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
ck_tile
::
is_same
<
T
,
float
>
::
value
||
ck_tile
::
is_same
<
T
,
double
>
::
value
||
ck_tile
::
is_same
<
T
,
ck_tile
::
fp16_t
>
::
value
||
ck_tile
::
is_same
<
T
,
int8_t
>::
value
||
ck_tile
::
is_same
<
T
,
int32_t
>
::
value
,
static_assert
(
std
::
is_same
_v
<
T
,
float
>
||
std
::
is_same
_v
<
T
,
double
>
||
std
::
is_same
_v
<
T
,
ck_tile
::
fp16_t
>
||
std
::
is_same_v
<
T
,
int8_t
>
||
std
::
is_same
_v
<
T
,
int32_t
>
,
"Data type is not supported by this operation!"
);
y
=
ck_tile
::
acos
(
x
);
...
...
@@ -653,9 +655,9 @@ struct Neg
template
<
typename
T
>
CK_TILE_HOST_DEVICE
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
ck_tile
::
is_same
<
T
,
float
>
::
value
||
ck_tile
::
is_same
<
T
,
double
>
::
value
||
ck_tile
::
is_same
<
T
,
ck_tile
::
fp16_t
>
::
value
||
ck_tile
::
is_same
<
T
,
int8_t
>::
value
||
ck_tile
::
is_same
<
T
,
int32_t
>
::
value
,
static_assert
(
std
::
is_same
_v
<
T
,
float
>
||
std
::
is_same
_v
<
T
,
double
>
||
std
::
is_same
_v
<
T
,
ck_tile
::
fp16_t
>
||
std
::
is_same_v
<
T
,
int8_t
>
||
std
::
is_same
_v
<
T
,
int32_t
>
,
"Data type is not supported by this operation!"
);
y
=
ck_tile
::
neg
(
x
);
...
...
@@ -667,9 +669,9 @@ struct ATan
template
<
typename
T
>
CK_TILE_HOST_DEVICE
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
ck_tile
::
is_same
<
T
,
float
>
::
value
||
ck_tile
::
is_same
<
T
,
double
>
::
value
||
ck_tile
::
is_same
<
T
,
ck_tile
::
fp16_t
>
::
value
||
ck_tile
::
is_same
<
T
,
int8_t
>::
value
||
ck_tile
::
is_same
<
T
,
int32_t
>
::
value
,
static_assert
(
std
::
is_same
_v
<
T
,
float
>
||
std
::
is_same
_v
<
T
,
double
>
||
std
::
is_same
_v
<
T
,
ck_tile
::
fp16_t
>
||
std
::
is_same_v
<
T
,
int8_t
>
||
std
::
is_same
_v
<
T
,
int32_t
>
,
"Data type is not supported by this operation!"
);
y
=
ck_tile
::
atan
(
x
);
...
...
@@ -681,9 +683,9 @@ struct Sin
template
<
typename
T
>
CK_TILE_HOST_DEVICE
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
ck_tile
::
is_same
<
T
,
float
>
::
value
||
ck_tile
::
is_same
<
T
,
double
>
::
value
||
ck_tile
::
is_same
<
T
,
ck_tile
::
fp16_t
>
::
value
||
ck_tile
::
is_same
<
T
,
int8_t
>::
value
||
ck_tile
::
is_same
<
T
,
int32_t
>
::
value
,
static_assert
(
std
::
is_same
_v
<
T
,
float
>
||
std
::
is_same
_v
<
T
,
double
>
||
std
::
is_same
_v
<
T
,
ck_tile
::
fp16_t
>
||
std
::
is_same_v
<
T
,
int8_t
>
||
std
::
is_same
_v
<
T
,
int32_t
>
,
"Data type is not supported by this operation!"
);
y
=
ck_tile
::
sin
(
x
);
...
...
@@ -695,9 +697,9 @@ struct ASinH
template
<
typename
T
>
CK_TILE_HOST_DEVICE
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
ck_tile
::
is_same
<
T
,
float
>
::
value
||
ck_tile
::
is_same
<
T
,
double
>
::
value
||
ck_tile
::
is_same
<
T
,
ck_tile
::
fp16_t
>
::
value
||
ck_tile
::
is_same
<
T
,
int8_t
>::
value
||
ck_tile
::
is_same
<
T
,
int32_t
>
::
value
,
static_assert
(
std
::
is_same
_v
<
T
,
float
>
||
std
::
is_same
_v
<
T
,
double
>
||
std
::
is_same
_v
<
T
,
ck_tile
::
fp16_t
>
||
std
::
is_same_v
<
T
,
int8_t
>
||
std
::
is_same
_v
<
T
,
int32_t
>
,
"Data type is not supported by this operation!"
);
y
=
ck_tile
::
asinh
(
x
);
...
...
@@ -709,9 +711,9 @@ struct Cos
template
<
typename
T
>
CK_TILE_HOST_DEVICE
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
ck_tile
::
is_same
<
T
,
float
>
::
value
||
ck_tile
::
is_same
<
T
,
double
>
::
value
||
ck_tile
::
is_same
<
T
,
ck_tile
::
fp16_t
>
::
value
||
ck_tile
::
is_same
<
T
,
int8_t
>::
value
||
ck_tile
::
is_same
<
T
,
int32_t
>
::
value
,
static_assert
(
std
::
is_same
_v
<
T
,
float
>
||
std
::
is_same
_v
<
T
,
double
>
||
std
::
is_same
_v
<
T
,
ck_tile
::
fp16_t
>
||
std
::
is_same_v
<
T
,
int8_t
>
||
std
::
is_same
_v
<
T
,
int32_t
>
,
"Data type is not supported by this operation!"
);
y
=
ck_tile
::
cos
(
x
);
...
...
@@ -723,9 +725,9 @@ struct ACosH
template
<
typename
T
>
CK_TILE_HOST_DEVICE
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
ck_tile
::
is_same
<
T
,
float
>
::
value
||
ck_tile
::
is_same
<
T
,
double
>
::
value
||
ck_tile
::
is_same
<
T
,
ck_tile
::
fp16_t
>
::
value
||
ck_tile
::
is_same
<
T
,
int8_t
>::
value
||
ck_tile
::
is_same
<
T
,
int32_t
>
::
value
,
static_assert
(
std
::
is_same
_v
<
T
,
float
>
||
std
::
is_same
_v
<
T
,
double
>
||
std
::
is_same
_v
<
T
,
ck_tile
::
fp16_t
>
||
std
::
is_same_v
<
T
,
int8_t
>
||
std
::
is_same
_v
<
T
,
int32_t
>
,
"Data type is not supported by this operation!"
);
y
=
ck_tile
::
acosh
(
x
);
...
...
@@ -737,9 +739,9 @@ struct Tan
template
<
typename
T
>
CK_TILE_HOST_DEVICE
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
ck_tile
::
is_same
<
T
,
float
>
::
value
||
ck_tile
::
is_same
<
T
,
double
>
::
value
||
ck_tile
::
is_same
<
T
,
ck_tile
::
fp16_t
>
::
value
||
ck_tile
::
is_same
<
T
,
int8_t
>::
value
||
ck_tile
::
is_same
<
T
,
int32_t
>
::
value
,
static_assert
(
std
::
is_same
_v
<
T
,
float
>
||
std
::
is_same
_v
<
T
,
double
>
||
std
::
is_same
_v
<
T
,
ck_tile
::
fp16_t
>
||
std
::
is_same_v
<
T
,
int8_t
>
||
std
::
is_same
_v
<
T
,
int32_t
>
,
"Data type is not supported by this operation!"
);
y
=
ck_tile
::
tan
(
x
);
...
...
@@ -751,9 +753,9 @@ struct ATanH
template
<
typename
T
>
CK_TILE_HOST_DEVICE
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
ck_tile
::
is_same
<
T
,
float
>
::
value
||
ck_tile
::
is_same
<
T
,
double
>
::
value
||
ck_tile
::
is_same
<
T
,
ck_tile
::
fp16_t
>
::
value
||
ck_tile
::
is_same
<
T
,
int8_t
>::
value
||
ck_tile
::
is_same
<
T
,
int32_t
>
::
value
,
static_assert
(
std
::
is_same
_v
<
T
,
float
>
||
std
::
is_same
_v
<
T
,
double
>
||
std
::
is_same
_v
<
T
,
ck_tile
::
fp16_t
>
||
std
::
is_same_v
<
T
,
int8_t
>
||
std
::
is_same
_v
<
T
,
int32_t
>
,
"Data type is not supported by this operation!"
);
y
=
ck_tile
::
atanh
(
x
);
...
...
@@ -765,9 +767,9 @@ struct SinH
template
<
typename
T
>
CK_TILE_HOST_DEVICE
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
ck_tile
::
is_same
<
T
,
float
>
::
value
||
ck_tile
::
is_same
<
T
,
double
>
::
value
||
ck_tile
::
is_same
<
T
,
ck_tile
::
fp16_t
>
::
value
||
ck_tile
::
is_same
<
T
,
int8_t
>::
value
||
ck_tile
::
is_same
<
T
,
int32_t
>
::
value
,
static_assert
(
std
::
is_same
_v
<
T
,
float
>
||
std
::
is_same
_v
<
T
,
double
>
||
std
::
is_same
_v
<
T
,
ck_tile
::
fp16_t
>
||
std
::
is_same_v
<
T
,
int8_t
>
||
std
::
is_same
_v
<
T
,
int32_t
>
,
"Data type is not supported by this operation!"
);
y
=
ck_tile
::
sinh
(
x
);
...
...
@@ -779,9 +781,9 @@ struct Ceil
template
<
typename
T
>
CK_TILE_HOST_DEVICE
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
ck_tile
::
is_same
<
T
,
float
>
::
value
||
ck_tile
::
is_same
<
T
,
double
>
::
value
||
ck_tile
::
is_same
<
T
,
ck_tile
::
fp16_t
>
::
value
||
ck_tile
::
is_same
<
T
,
int8_t
>::
value
||
ck_tile
::
is_same
<
T
,
int32_t
>
::
value
,
static_assert
(
std
::
is_same
_v
<
T
,
float
>
||
std
::
is_same
_v
<
T
,
double
>
||
std
::
is_same
_v
<
T
,
ck_tile
::
fp16_t
>
||
std
::
is_same_v
<
T
,
int8_t
>
||
std
::
is_same
_v
<
T
,
int32_t
>
,
"Data type is not supported by this operation!"
);
y
=
ck_tile
::
ceil
(
x
);
...
...
@@ -793,9 +795,9 @@ struct Exp
template
<
typename
T
>
CK_TILE_HOST_DEVICE
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
ck_tile
::
is_same
<
T
,
float
>
::
value
||
ck_tile
::
is_same
<
T
,
double
>
::
value
||
ck_tile
::
is_same
<
T
,
ck_tile
::
fp16_t
>
::
value
||
ck_tile
::
is_same
<
T
,
int8_t
>::
value
||
ck_tile
::
is_same
<
T
,
int32_t
>
::
value
,
static_assert
(
std
::
is_same
_v
<
T
,
float
>
||
std
::
is_same
_v
<
T
,
double
>
||
std
::
is_same
_v
<
T
,
ck_tile
::
fp16_t
>
||
std
::
is_same_v
<
T
,
int8_t
>
||
std
::
is_same
_v
<
T
,
int32_t
>
,
"Data type is not supported by this operation!"
);
y
=
ck_tile
::
exp
(
x
);
...
...
@@ -807,9 +809,9 @@ struct CosH
template
<
typename
T
>
CK_TILE_HOST_DEVICE
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
ck_tile
::
is_same
<
T
,
float
>
::
value
||
ck_tile
::
is_same
<
T
,
double
>
::
value
||
ck_tile
::
is_same
<
T
,
ck_tile
::
fp16_t
>
::
value
||
ck_tile
::
is_same
<
T
,
int8_t
>::
value
||
ck_tile
::
is_same
<
T
,
int32_t
>
::
value
,
static_assert
(
std
::
is_same
_v
<
T
,
float
>
||
std
::
is_same
_v
<
T
,
double
>
||
std
::
is_same
_v
<
T
,
ck_tile
::
fp16_t
>
||
std
::
is_same_v
<
T
,
int8_t
>
||
std
::
is_same
_v
<
T
,
int32_t
>
,
"Data type is not supported by this operation!"
);
y
=
ck_tile
::
cosh
(
x
);
...
...
@@ -821,9 +823,9 @@ struct Floor
template
<
typename
T
>
CK_TILE_HOST_DEVICE
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
ck_tile
::
is_same
<
T
,
float
>
::
value
||
ck_tile
::
is_same
<
T
,
double
>
::
value
||
ck_tile
::
is_same
<
T
,
ck_tile
::
fp16_t
>
::
value
||
ck_tile
::
is_same
<
T
,
int8_t
>::
value
||
ck_tile
::
is_same
<
T
,
int32_t
>
::
value
,
static_assert
(
std
::
is_same
_v
<
T
,
float
>
||
std
::
is_same
_v
<
T
,
double
>
||
std
::
is_same
_v
<
T
,
ck_tile
::
fp16_t
>
||
std
::
is_same_v
<
T
,
int8_t
>
||
std
::
is_same
_v
<
T
,
int32_t
>
,
"Data type is not supported by this operation!"
);
y
=
ck_tile
::
floor
(
x
);
...
...
@@ -835,9 +837,9 @@ struct Log
template
<
typename
T
>
CK_TILE_HOST_DEVICE
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
ck_tile
::
is_same
<
T
,
float
>
::
value
||
ck_tile
::
is_same
<
T
,
double
>
::
value
||
ck_tile
::
is_same
<
T
,
ck_tile
::
fp16_t
>
::
value
||
ck_tile
::
is_same
<
T
,
int8_t
>::
value
||
ck_tile
::
is_same
<
T
,
int32_t
>
::
value
,
static_assert
(
std
::
is_same
_v
<
T
,
float
>
||
std
::
is_same
_v
<
T
,
double
>
||
std
::
is_same
_v
<
T
,
ck_tile
::
fp16_t
>
||
std
::
is_same_v
<
T
,
int8_t
>
||
std
::
is_same
_v
<
T
,
int32_t
>
,
"Data type is not supported by this operation!"
);
y
=
ck_tile
::
log
(
x
);
...
...
@@ -849,9 +851,9 @@ struct ASin
template
<
typename
T
>
CK_TILE_HOST_DEVICE
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
ck_tile
::
is_same
<
T
,
float
>
::
value
||
ck_tile
::
is_same
<
T
,
double
>
::
value
||
ck_tile
::
is_same
<
T
,
ck_tile
::
fp16_t
>
::
value
||
ck_tile
::
is_same
<
T
,
int8_t
>::
value
||
ck_tile
::
is_same
<
T
,
int32_t
>
::
value
,
static_assert
(
std
::
is_same
_v
<
T
,
float
>
||
std
::
is_same
_v
<
T
,
double
>
||
std
::
is_same
_v
<
T
,
ck_tile
::
fp16_t
>
||
std
::
is_same_v
<
T
,
int8_t
>
||
std
::
is_same
_v
<
T
,
int32_t
>
,
"Data type is not supported by this operation!"
);
y
=
ck_tile
::
asin
(
x
);
...
...
@@ -863,9 +865,9 @@ struct Rcp
template
<
typename
T
>
CK_TILE_HOST_DEVICE
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
ck_tile
::
is_same
<
T
,
float
>
::
value
||
ck_tile
::
is_same
<
T
,
double
>
::
value
||
ck_tile
::
is_same
<
T
,
ck_tile
::
fp16_t
>
::
value
||
ck_tile
::
is_same
<
T
,
int8_t
>::
value
||
ck_tile
::
is_same
<
T
,
int32_t
>
::
value
,
static_assert
(
std
::
is_same
_v
<
T
,
float
>
||
std
::
is_same
_v
<
T
,
double
>
||
std
::
is_same
_v
<
T
,
ck_tile
::
fp16_t
>
||
std
::
is_same_v
<
T
,
int8_t
>
||
std
::
is_same
_v
<
T
,
int32_t
>
,
"Data type is not supported by this operation!"
);
y
=
ck_tile
::
rcp
(
x
);
...
...
@@ -879,12 +881,12 @@ struct Swish
template
<
typename
Y
,
typename
X
>
CK_TILE_HOST_DEVICE
void
operator
()(
Y
&
y
,
const
X
&
x
)
const
{
static_assert
(
ck_tile
::
is_same
<
X
,
float
>
::
value
||
ck_tile
::
is_same
<
X
,
double
>
::
value
||
ck_tile
::
is_same
<
X
,
ck_tile
::
fp16_t
>
::
value
,
static_assert
(
std
::
is_same
_v
<
X
,
float
>
||
std
::
is_same
_v
<
X
,
double
>
||
std
::
is_same
_v
<
X
,
ck_tile
::
fp16_t
>
,
"Data type is not supported by this operation!"
);
static_assert
(
ck_tile
::
is_same
<
Y
,
float
>
::
value
||
ck_tile
::
is_same
<
Y
,
double
>
::
value
||
ck_tile
::
is_same
<
Y
,
ck_tile
::
fp16_t
>
::
value
,
static_assert
(
std
::
is_same
_v
<
Y
,
float
>
||
std
::
is_same
_v
<
Y
,
double
>
||
std
::
is_same
_v
<
Y
,
ck_tile
::
fp16_t
>
,
"Data type is not supported by this operation!"
);
float
bx
=
-
beta_
*
type_convert
<
float
>
(
x
);
...
...
@@ -901,9 +903,9 @@ struct SoftRelu
template
<
typename
T
>
CK_TILE_HOST_DEVICE
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
ck_tile
::
is_same
<
T
,
float
>
::
value
||
ck_tile
::
is_same
<
T
,
double
>
::
value
||
ck_tile
::
is_same
<
T
,
ck_tile
::
fp16_t
>
::
value
||
ck_tile
::
is_same
<
T
,
int32_t
>::
value
||
ck_tile
::
is_same
<
T
,
int8_t
>
::
value
,
static_assert
(
std
::
is_same
_v
<
T
,
float
>
||
std
::
is_same
_v
<
T
,
double
>
||
std
::
is_same
_v
<
T
,
ck_tile
::
fp16_t
>
||
std
::
is_same_v
<
T
,
int32_t
>
||
std
::
is_same
_v
<
T
,
int8_t
>
,
"Data type is not supported by this operation!"
);
T
casted_alpha
=
type_convert
<
T
>
(
alpha_
);
constexpr
T
one
=
type_convert
<
T
>
(
1
);
...
...
@@ -920,9 +922,9 @@ struct Power
template
<
typename
T
>
CK_TILE_HOST_DEVICE
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
ck_tile
::
is_same
<
T
,
float
>
::
value
||
ck_tile
::
is_same
<
T
,
double
>
::
value
||
ck_tile
::
is_same
<
T
,
ck_tile
::
fp16_t
>
::
value
||
ck_tile
::
is_same
<
T
,
int32_t
>::
value
||
ck_tile
::
is_same
<
T
,
int8_t
>
::
value
,
static_assert
(
std
::
is_same
_v
<
T
,
float
>
||
std
::
is_same
_v
<
T
,
double
>
||
std
::
is_same
_v
<
T
,
ck_tile
::
fp16_t
>
||
std
::
is_same_v
<
T
,
int32_t
>
||
std
::
is_same
_v
<
T
,
int8_t
>
,
"Data type is not supported by this operation!"
);
T
casted_alpha
=
type_convert
<
T
>
(
alpha_
);
T
casted_beta
=
type_convert
<
T
>
(
beta_
);
...
...
@@ -942,9 +944,9 @@ struct ClippedRelu
template
<
typename
T
>
CK_TILE_HOST_DEVICE
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
ck_tile
::
is_same
<
T
,
float
>
::
value
||
ck_tile
::
is_same
<
T
,
double
>
::
value
||
ck_tile
::
is_same
<
T
,
ck_tile
::
fp16_t
>
::
value
||
ck_tile
::
is_same
<
T
,
int32_t
>::
value
||
ck_tile
::
is_same
<
T
,
int8_t
>
::
value
,
static_assert
(
std
::
is_same
_v
<
T
,
float
>
||
std
::
is_same
_v
<
T
,
double
>
||
std
::
is_same
_v
<
T
,
ck_tile
::
fp16_t
>
||
std
::
is_same_v
<
T
,
int32_t
>
||
std
::
is_same
_v
<
T
,
int8_t
>
,
"Data type is not supported by this operation!"
);
T
casted_alpha
=
type_convert
<
T
>
(
alpha_
);
T
casted_beta
=
type_convert
<
T
>
(
beta_
);
...
...
@@ -961,9 +963,9 @@ struct LeakyRelu
template
<
typename
T
>
CK_TILE_HOST_DEVICE
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
ck_tile
::
is_same
<
T
,
float
>
::
value
||
ck_tile
::
is_same
<
T
,
double
>
::
value
||
ck_tile
::
is_same
<
T
,
ck_tile
::
fp16_t
>
::
value
||
ck_tile
::
is_same
<
T
,
int32_t
>::
value
||
ck_tile
::
is_same
<
T
,
int8_t
>
::
value
,
static_assert
(
std
::
is_same
_v
<
T
,
float
>
||
std
::
is_same
_v
<
T
,
double
>
||
std
::
is_same
_v
<
T
,
ck_tile
::
fp16_t
>
||
std
::
is_same_v
<
T
,
int32_t
>
||
std
::
is_same
_v
<
T
,
int8_t
>
,
"Data type is not supported by this operation!"
);
T
casted_alpha
=
type_convert
<
T
>
(
alpha_
);
y
=
x
>=
0
?
x
:
x
*
casted_alpha
;
...
...
@@ -978,9 +980,9 @@ struct Elu
template
<
typename
T
>
CK_TILE_HOST_DEVICE
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
ck_tile
::
is_same
<
T
,
float
>
::
value
||
ck_tile
::
is_same
<
T
,
double
>
::
value
||
ck_tile
::
is_same
<
T
,
ck_tile
::
fp16_t
>
::
value
||
ck_tile
::
is_same
<
T
,
int32_t
>::
value
||
ck_tile
::
is_same
<
T
,
int8_t
>
::
value
,
static_assert
(
std
::
is_same
_v
<
T
,
float
>
||
std
::
is_same
_v
<
T
,
double
>
||
std
::
is_same
_v
<
T
,
ck_tile
::
fp16_t
>
||
std
::
is_same_v
<
T
,
int32_t
>
||
std
::
is_same
_v
<
T
,
int8_t
>
,
"Data type is not supported by this operation!"
);
T
casted_alpha
=
type_convert
<
T
>
(
alpha_
);
y
=
x
>
0
?
x
:
casted_alpha
*
ck_tile
::
expm1
(
x
);
...
...
@@ -995,9 +997,9 @@ struct Logistic
template
<
typename
T
>
CK_TILE_HOST_DEVICE
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
ck_tile
::
is_same
<
T
,
float
>
::
value
||
ck_tile
::
is_same
<
T
,
double
>
::
value
||
ck_tile
::
is_same
<
T
,
ck_tile
::
fp16_t
>
::
value
||
ck_tile
::
is_same
<
T
,
int32_t
>::
value
||
ck_tile
::
is_same
<
T
,
int8_t
>
::
value
,
static_assert
(
std
::
is_same
_v
<
T
,
float
>
||
std
::
is_same
_v
<
T
,
double
>
||
std
::
is_same
_v
<
T
,
ck_tile
::
fp16_t
>
||
std
::
is_same_v
<
T
,
int32_t
>
||
std
::
is_same
_v
<
T
,
int8_t
>
,
"Data type is not supported by this operation!"
);
T
casted_alpha
=
type_convert
<
T
>
(
alpha_
);
constexpr
T
one
=
type_convert
<
T
>
(
1
);
...
...
@@ -1078,7 +1080,7 @@ struct ConvScaleRelu
};
// support fastconvert of int8 to fp16
#if 0
template <typename InputDataType, typename OutputDataType, index_t RegPackNumber>
struct FastNumericArrayConverter
{
...
...
@@ -1146,6 +1148,6 @@ struct FastNumericArrayConverter<uint8_t, ck_tile::fp16_t, N>
CK_TILE_DEVICE OutputArray operator()(InputArray const& Input) { return convert(Input); }
};
#endif
}
// namespace element_wise
}
// namespace ck_tile
include/ck_tile/ops/reduce/block/block_reduce.hpp
View file @
667047b9
...
...
@@ -7,6 +7,10 @@
namespace
ck_tile
{
/*
* TODO: block_tile_reduce_sync() currently has a limitation
* Y dim must have at least one dim not been reduced
*/
// synchronize reduce result (cross lane reduction and broadcast on replicated dimension)
template
<
typename
AccDistributedTensor_
,
typename
ReduceFunc
,
bool
WithBroadcast
=
true
>
CK_TILE_DEVICE
void
block_tile_reduce_sync
(
AccDistributedTensor_
&
acc_tensor
,
...
...
@@ -55,7 +59,17 @@ CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor,
// pull data from remote lane
const
auto
v_remote
=
warp_shuffle_down
(
v_local
,
lid_delta
);
#if 0
if constexpr(Verbose_)
{
printf("warp_shuffle_down : %d - %d, %d (%.3f, %.3f)\n",
static_cast<int>(threadIdx.x),
static_cast<int>(lid_over_rid_derivative),
static_cast<int>(lid_delta),
v_local,
v_remote);
}
#endif
// reduce
v_local
=
reduce_func
(
v_local
,
v_remote
);
});
...
...
@@ -104,6 +118,76 @@ CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor,
});
}
/*
 * this version is faster, using xor to do reduce, no need broadcast anymore
 * TODO: the limitation is to-be-reduced P dim can only mapping to one R dim?
 */
// Cross-lane reduction of a distributed tensor's per-thread buffer using a xor
// (butterfly) shuffle pattern: after log2(r_length) stages every participating lane
// already holds the fully reduced value, so no separate broadcast step is required
// (unlike block_tile_reduce_sync above).
//
// acc_tensor  - distributed tensor, reduced in place across the lane-mapped R dims
// reduce_func - binary combiner (e.g. max, plus); must be associative/commutative
//               for the butterfly order to be equivalent to a sequential reduce
template <typename AccDistributedTensor_, typename ReduceFunc>
CK_TILE_DEVICE void block_tile_reduce_xor_sync(AccDistributedTensor_& acc_tensor,
                                               const ReduceFunc& reduce_func)
{
    using Dstr             = typename AccDistributedTensor_::StaticTileDistribution;
    using DstrEncode       = typename Dstr::DstrEncode;
    using DstrEncodeDetail = typename DstrEncode::detail;

    constexpr index_t NDimP = Dstr::get_num_of_dimension_p();
    constexpr index_t NDimR = Dstr::get_num_of_dimension_r();

    // the last P dimension is the one mapped to the lane id
    constexpr index_t idim_p_lane = NDimP - 1;

    constexpr index_t thread_buf_size = AccDistributedTensor_::get_thread_buffer_size();

    // loop over thread data
    static_for<0, thread_buf_size, 1>{}([&](auto i) {
        auto v_local = acc_tensor.get_thread_buffer()[i];

        // cross-lane reduce for replication
        // only reduce on R dimension correspond to lane
        // (lane id maps to this R dimension)
        static_for<0, NDimR, 1>{}([&](auto idim_r) {
            // FIXME: nasty to use does_p_own_r_
            if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r])
            {
                constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];

                // lane-id stride between neighboring entries of this R dimension
                constexpr index_t lid_over_rid_derivative =
                    DstrEncodeDetail::ps_over_rs_derivative_[idim_p_lane][idim_r];

                static_assert(is_power_of_two_integer(r_length),
                              "wrong! only support power of 2 reduction");

                constexpr index_t nstage = integer_log2_floor(r_length);

                // reduction sweep forward: xor the (scaled) lane id with 1, 2, 4, ...
                static_for<0, nstage, 1>{}([&](auto istage) {
                    // TODO: lid_over_rid_derivative not ok in xor? maybe need limit the usage of
                    // xor
                    index_t src_lane = (__lane_id() * lid_over_rid_derivative) ^
                                       (number<1 << istage.value>{}.value);

                    // pull data from remote lane
                    const auto v_remote = warp_shuffle(v_local, src_lane);
#if 0
                    if constexpr(Verbose_)
                    {
                        printf("block_tile_reduce_xor_sync : %d - %d, %d (%.3f, %.3f)\n",
                               static_cast<int>(threadIdx.x),
                               static_cast<int>(istage),
                               static_cast<int>(src_lane),
                               v_local,
                               v_remote);
                    }
#endif
                    // reduce
                    v_local = reduce_func(v_local, v_remote);
                });
            }
        });

        // write the fully-reduced value back into the thread buffer
        acc_tensor.get_thread_buffer()(i) = v_local;
    });
}
// FIXME: this is for 2D to 1D reduce only, need to support n-D
template
<
typename
AccDistributedTensor_
,
typename
InDistributedTensor_
,
...
...
@@ -175,6 +259,10 @@ CK_TILE_DEVICE void block_tile_reduce(AccDistributedTensor_& acc_tensor,
#endif
}
/*
* TODO: block_tile_reduce() currently has a limitation
* Y dim must have at least one dim not been reduced
*/
template
<
typename
AccDataType_
,
typename
InDistributedTensor_
,
index_t
...
InReduceDims
,
...
...
include/ck_tile/ops/softmax.hpp
0 → 100644
View file @
667047b9
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/softmax/block/block_softmax_2d.hpp"
#include "ck_tile/ops/softmax/block/block_softmax_2d_problem.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/softmax/block/block_softmax_2d.hpp
0 → 100644
View file @
667047b9
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/reduce.hpp"
namespace
ck_tile
{
/*
  simple 2d softmax implementation, along row (dim=1)
  requirement:
      1). each row is within a warp
      2). data type must be a dword
*/
// Numerically-stable row-wise softmax over a 2D distributed tensor:
// y = exp(x - rowmax(x)) / rowsum(exp(x - rowmax(x))).
// Cross-lane combines use block_tile_reduce_xor_sync, so every lane holding a
// replica of a row ends up with the same max/sum without an explicit broadcast.
template <typename Problem_, typename Policy_ = void>
struct BlockSoftmax2D
{
    using Problem  = remove_cvref_t<Problem_>;
    using Policy   = remove_cvref_t<Policy_>;
    // element/accumulator type for the reductions and the final scaling
    using DataType = typename Problem::DataType;

    // In-place-style overload: reads x, writes the softmax result into y.
    // `dim` selects the reduced dimension (default 1, i.e. along the row).
    template <typename DistributedTensor, index_t dim = 1>
    CK_TILE_DEVICE void
    operator()(const DistributedTensor& x, DistributedTensor& y, number<dim> = {})
    {
        const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
        const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };

        // compute row max (init with -inf so empty/padded lanes never win)
        auto row_max =
            block_tile_reduce<DataType>(x, sequence<dim>{}, f_max, -numeric<DataType>::infinity());
        block_tile_reduce_xor_sync(row_max, f_max);

        // compute elementwise softmax numerator: exp(x - rowmax)
        constexpr auto span_2d = DistributedTensor::get_distributed_spans();
        sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
            constexpr auto i_idx = make_tuple(idx0);
            sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
                constexpr auto i_j_idx = make_tuple(idx0, idx1);
                y(i_j_idx)             = exp(x[i_j_idx] - row_max(i_idx));
            });
        });

        // compute row sum of the numerators
        auto row_sum = block_tile_reduce<DataType>(y, sequence<dim>{}, f_sum, DataType{0});
        block_tile_reduce_xor_sync(row_sum, f_sum);

        // reciprocal of the row sum, computed once per row and reused below
        auto r = make_static_distributed_tensor<DataType>(row_sum.get_tile_distribution());
        constexpr auto span_1d = decltype(r)::get_distributed_spans();
        sweep_tile_span(span_1d[number<0>{}], [&](auto idx0) {
            constexpr auto i_idx = make_tuple(idx0);
            r(i_idx)             = DataType{1} / row_sum(i_idx);
        });

        // scale numerators by the reciprocal -> final probabilities
        sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
            constexpr auto i_idx = make_tuple(idx0);
            sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
                constexpr auto i_j_idx = make_tuple(idx0, idx1);
                y(i_j_idx)             = y(i_j_idx) * r(i_idx);
            });
        });
    }

    // Value-returning overload: allocates the result tensor and forwards to the
    // two-argument overload above.
    template <typename DistributedTensor, index_t dim = 1>
    CK_TILE_DEVICE decltype(auto) operator()(const DistributedTensor& x, number<dim> = {})
    {
        auto y = DistributedTensor{}; // distributed tensor
        operator()(x, y, number<dim>{});
        return y;
    }
};
}
// namespace ck_tile
include/ck_tile/ops/softmax/block/block_softmax_2d_problem.hpp
0 → 100644
View file @
667047b9
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
// Compile-time problem descriptor for BlockSoftmax2D.
template <typename DataType_>
struct BlockSoftmax2DProblem
{
    // element type used for the softmax computation (reductions and scaling)
    using DataType = remove_cvref_t<DataType_>;
};
}
// namespace ck_tile
include/ck_tile/ops/topk.hpp
0 → 100644
View file @
667047b9
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/topk/block/block_topk_stream_2d.hpp"
#include "ck_tile/ops/topk/block/block_topk_stream_2d_problem.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/topk/block/block_topk_stream_2d.hpp
0 → 100644
View file @
667047b9
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
/*
  simple 2d topk implementation, along row (dim=1)
  requirement:
      1). each row is within a warp
*/
// Streaming top-k along each row of a 2D distributed tensor: k iterations of a
// cross-lane argmax; after each iteration the winning element is masked to -inf
// in a local copy and the value/index windows are advanced by one column.
template <typename Problem_, typename Policy_ = void>
struct BlockTopkStream2D
{
    using Problem   = remove_cvref_t<Problem_>;
    using Policy    = remove_cvref_t<Policy_>;
    using DataType  = typename Problem::DataType;  // value type of the ranked elements
    using IndexType = typename Problem::IndexType; // type of the emitted column indices

    // TODO: if DataType is subdword, need pack into single dword to use argmax
    // value/index pair carried through the argmax reduction
    struct ArgmaxPacket
    {
        DataType arg;  // candidate value
        index_t value; // column index the candidate came from
    };

    // x          - input tile (consumed by value; a masked working copy is kept)
    // out_window - destination for the k selected values, one column per iteration
    // idx_window - destination for the k selected column indices
    // k          - number of elements to select per row (runtime value)
    // NOTE(review): the reduction below hard-codes sequence<1>, so despite the
    // `dim` template parameter only dim==1 (along the row) behaves as intended.
    template <typename DistributedTensor, typename OutWindow, typename IdxWindow, index_t dim = 1>
    CK_TILE_DEVICE void operator()(const DistributedTensor& x,
                                   OutWindow& out_window,
                                   IdxWindow& idx_window,
                                   index_t k,
                                   number<dim> = {})
    {
        // static_assert(OutWindow::get_window_lengths()[number<1>] == 1);
        static_assert(
            std::is_same_v<typename DistributedTensor::DataType, typename OutWindow::DataType> &&
            std::is_same_v<typename DistributedTensor::DataType, DataType>);
        static_assert(std::is_same_v<typename IdxWindow::DataType, IndexType>);
        // working copy; winners get masked to -inf so they cannot be picked again
        DistributedTensor x_tmp = x;
        constexpr auto dst_dist = typename IdxWindow::TileDstr{};

        // argmax for topk (ties resolved in favor of e1, the right operand)
        const auto f_argmax = [](ArgmaxPacket e0, ArgmaxPacket e1) {
            return e0.arg > e1.arg ? e0 : e1;
        };

        for(index_t i_k = 0; i_k < k; i_k++)
        {
            constexpr auto span_2d = DistributedTensor::get_distributed_spans();

            // pack each (value, column) pair so a single reduction yields the argmax
            auto packet = [&]() {
                auto tmp = make_static_distributed_tensor<ArgmaxPacket>(x.get_tile_distribution());
                sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
                    sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
                        const auto tile_idx = get_x_indices_from_distributed_indices(
                            tmp.get_tile_distribution(), make_tuple(idx0, idx1));
                        constexpr auto i_j_idx = make_tuple(idx0, idx1);
                        ArgmaxPacket t;
                        t.arg   = x_tmp(i_j_idx); // !!! we reference x here
                        t.value = tile_idx.at(number<1>{});
                        tmp(i_j_idx) = t;
                    });
                });
                return tmp;
            }();

            // -inf sentinel so real data always wins the argmax
            auto argmax_init = ArgmaxPacket{-numeric<DataType>::infinity(), 0};
            auto r = block_tile_reduce<ArgmaxPacket>(packet, sequence<1>{}, f_argmax, argmax_init);
            block_tile_reduce_xor_sync(r, f_argmax);

            // unpack this iteration's winner into value/index tiles for the store
            auto o = make_static_distributed_tensor<DataType>(dst_dist);
            auto i = make_static_distributed_tensor<IndexType>(dst_dist);
            sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
                sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
                    constexpr auto i_j_idx = make_tuple(idx0, idx1);
                    ArgmaxPacket tmp = r(i_j_idx);
                    o(i_j_idx) = tmp.arg;
                    i(i_j_idx) = tmp.value;
                });
            });

            // update value: mask the winning column to -inf in the working copy
            sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
                sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
                    const auto tile_idx = get_x_indices_from_distributed_indices(
                        x.get_tile_distribution(), make_tuple(idx0, idx1));
                    auto col_id = tile_idx.at(number<1>{});
                    constexpr auto i_j_idx = make_tuple(idx0, idx1);
                    x_tmp(i_j_idx) = (col_id == r(i_j_idx).value) ? -numeric<DataType>::infinity()
                                                                  : x_tmp(i_j_idx);
                });
            });

            // after the xor-reduce every lane of a row holds the same result;
            // only the first of the ColLanes replicas writes it out
            if(threadIdx.x % Problem::ColLanes == 0)
            {
                store_tile(out_window, o);
                store_tile(idx_window, i);
            }
            // advance both destinations by one column for the next selection
            move_tile_window(out_window, {number<0>{}, number<1>{}});
            move_tile_window(idx_window, {number<0>{}, number<1>{}});
        }
    }
};
}
// namespace ck_tile
include/ck_tile/ops/topk/block/block_topk_stream_2d_problem.hpp
0 → 100644
View file @
667047b9
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
/*
  simple 2d topk implementation, along row (dim=1)
  requirement:
      1). each row is within a warp
*/
// Compile-time problem descriptor for BlockTopkStream2D.
template <typename DataType_, typename IndexType_, index_t ColLanes_>
struct BlockTopkStream2DProblem
{
    using DataType  = remove_cvref_t<DataType_>;  // value type of the ranked elements
    using IndexType = remove_cvref_t<IndexType_>; // type of the emitted column indices
    // replication factor along the row: lanes whose (threadIdx.x % ColLanes != 0)
    // hold duplicates and are skipped when storing results
    static constexpr index_t ColLanes = ColLanes_;
};
}
// namespace ck_tile
include/ck_tile/ops/topk_softmax.hpp
0 → 100644
View file @
667047b9
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/topk_softmax/kernel/topk_softmax_kernel.hpp"
#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_pipeline.hpp"
#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_policy.hpp"
#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_problem.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/topk_softmax/kernel/topk_softmax_kernel.hpp
0 → 100644
View file @
667047b9
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/elementwise.hpp"
#include <string>
#include <type_traits>
namespace
ck_tile
{
// Host-side arguments for one topk-softmax launch. Buffers are type-erased here;
// the kernel casts them to Problem::InputType / WeightType / IndexType.
struct TopkSoftmaxHostArgs
{
    const void* p_input;  // logits, viewed as [num_rows, num_experts], packed row-major
    void* p_output;       // selected weights, viewed as [num_rows, topk]
    void* p_indices;      // selected expert indices, viewed as [num_rows, topk]
    index_t num_rows;     // number of rows (tokens) to process
    index_t num_experts;  // columns of the input (experts per row)
    index_t topk;         // k: experts selected per row
};
// Kernel wrapper for the fused topk-softmax: each block handles
// Problem::RowsPerBlock rows, builds padded global-memory tile windows for the
// input logits and the (value, index) outputs, then hands them to the Pipeline.
template <typename Pipeline_>
struct TopkSoftmaxKernel
{
    using Pipeline   = remove_cvref_t<Pipeline_>;
    using Problem    = remove_cvref_t<typename Pipeline::Problem>;
    using InputType  = typename Problem::InputType;  // logits element type
    using WeightType = typename Problem::WeightType; // softmax weight element type
    using IndexType  = typename Problem::IndexType;  // expert index element type

    // Device-side copy of TopkSoftmaxHostArgs (same fields, passed by value).
    struct TopkSoftmaxKargs
    {
        const void* p_input;
        void* p_output;
        void* p_indices;
        index_t num_rows;
        index_t num_experts;
        index_t topk;
    };

    using Kargs = TopkSoftmaxKargs;
    using Hargs = TopkSoftmaxHostArgs;

    // One block covers WarpsPerBlock warps, each warp covering RowsPerWarp rows;
    // both divisions are rounded up so trailing rows get a (padded) block.
    CK_TILE_HOST static constexpr auto GridSize(const Hargs& h)
    {
        const int num_warps  = (h.num_rows + Problem::RowsPerWarp - 1) / Problem::RowsPerWarp;
        const int num_blocks = (num_warps + Problem::WarpsPerBlock - 1) / Problem::WarpsPerBlock;
        return dim3(num_blocks);
    }

    // Field-by-field copy of the host args into the device karg struct.
    CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
    {
        Kargs k;
        k.p_input     = h.p_input;
        k.p_output    = h.p_output;
        k.p_indices   = h.p_indices;
        k.num_rows    = h.num_rows;
        k.num_experts = h.num_experts;
        k.topk        = h.topk;
        return k;
    }

    CK_TILE_HOST_DEVICE static constexpr auto BlockSize() { return Problem::BlockSize; }

    CK_TILE_DEVICE void operator()(Kargs kargs) const
    {
        // first row owned by this block
        index_t block_row_id = static_cast<index_t>(blockIdx.x * Problem::RowsPerBlock);

        // FIX: apply the per-block row offset exactly once, via the tile-window
        // origin {block_row_id, 0}. Previously the base pointers were *also*
        // advanced by blockIdx.x * RowsPerBlock * stride, so every block past the
        // first addressed rows at twice its offset, and the pad bound (num_rows)
        // no longer matched the shifted view base.
        const auto input_window = [&]() {
            const InputType* p_input = reinterpret_cast<const InputType*>(kargs.p_input);

            // packed [num_rows, num_experts] view over the logits
            auto tmp = make_naive_tensor_view_packed<address_space_enum::global>(
                p_input,
                make_tuple(kargs.num_rows, kargs.num_experts),
                number<Problem::VectorSize>{});
            // pad both dims up to the compile-time tile shape
            auto view = pad_tensor_view(
                tmp,
                make_tuple(number<Problem::RowsPerBlock>{}, number<Problem::Experts>{}),
                sequence<1, 1>{});
            return make_tile_window(
                view,
                make_tuple(number<Problem::RowsPerBlock>{}, number<Problem::Experts>{}),
                {block_row_id, 0});
        }();

        auto output_window = [&]() {
            WeightType* p_output = reinterpret_cast<WeightType*>(kargs.p_output);

            // packed [num_rows, topk] view; the pipeline writes one column per
            // top-k iteration, so the window is one column wide (dim 1 unpadded)
            auto tmp = make_naive_tensor_view_packed<address_space_enum::global>(
                p_output, make_tuple(kargs.num_rows, kargs.topk), number<Problem::VectorSize>{});
            auto view = pad_tensor_view(
                tmp, make_tuple(number<Problem::RowsPerBlock>{}, number<1>{}), sequence<1, 0>{});
            return make_tile_window(view,
                                    make_tuple(number<Problem::RowsPerBlock>{}, number<1>{}),
                                    {block_row_id, 0});
        }();

        auto indices_window = [&]() {
            IndexType* p_indices = reinterpret_cast<IndexType*>(kargs.p_indices);

            // packed [num_rows, topk] view for the selected expert indices
            auto tmp = make_naive_tensor_view_packed<address_space_enum::global>(
                p_indices, make_tuple(kargs.num_rows, kargs.topk), number<Problem::VectorSize>{});
            auto view = pad_tensor_view(
                tmp, make_tuple(number<Problem::RowsPerBlock>{}, number<1>{}), sequence<1, 0>{});
            return make_tile_window(view,
                                    make_tuple(number<Problem::RowsPerBlock>{}, number<1>{}),
                                    {block_row_id, 0});
        }();

        Pipeline{}(input_window, output_window, indices_window, kargs.topk);
    }
};
}
// namespace ck_tile
include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_pipeline.hpp
0 → 100644
View file @
667047b9
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_policy.hpp"
#include <string>
#include <type_traits>
namespace
ck_tile
{
// Warp-per-row topk-softmax pipeline: load one tile of logits, convert to the
// weight type, apply row-wise softmax, then stream the per-row top-k values and
// indices into the destination windows.
template <typename Problem_, typename Policy_ = TopkSoftmaxWarpPerRowPolicy>
struct TopkSoftmaxWarpPerRowPipeline
{
    // TODO: this kernel only support warp per row
    using Problem = remove_cvref_t<Problem_>;
    using Policy  = remove_cvref_t<Policy_>;

    template <typename InputWindow, typename OutputWindow, typename IndexWindow>
    CK_TILE_DEVICE auto operator()(const InputWindow& input_window,
                                   OutputWindow& out_window,
                                   IndexWindow& idx_window,
                                   index_t k)
    {
        // Rebuild both destination windows with the output distribution this
        // policy expects (same view/lengths/origin as the callers' windows).
        auto out_tile_win = make_tile_window(out_window.get_bottom_tensor_view(),
                                             out_window.get_window_lengths(),
                                             out_window.get_window_origin(),
                                             Policy::template MakeOutputDistribution<Problem>());
        auto idx_tile_win = make_tile_window(idx_window.get_bottom_tensor_view(),
                                             idx_window.get_window_lengths(),
                                             idx_window.get_window_origin(),
                                             Policy::template MakeOutputDistribution<Problem>());

        // Load the logits tile through the policy's input distribution and
        // widen it to the weight type before any arithmetic.
        auto in_tile_win = make_tile_window(input_window.get_bottom_tensor_view(),
                                            input_window.get_window_lengths(),
                                            input_window.get_window_origin(),
                                            Policy::template MakeInputDistribution<Problem>());
        const auto raw_tile = load_tile(in_tile_win);
        auto weight_tile    = cast_tile<typename Problem::WeightType>(raw_tile);

        // softmax along each row, then emit the k largest entries per row
        auto prob_tile = Policy::template GetSoftmax<Problem>()(weight_tile);
        Policy::template GetTopk<Problem>()(prob_tile, out_tile_win, idx_tile_win, k);
    }
};
}
// namespace ck_tile
include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_policy.hpp
0 → 100644
View file @
667047b9
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/softmax.hpp"
#include "ck_tile/ops/topk.hpp"
namespace
ck_tile
{
// Policy for the warp-per-row topk-softmax pipeline: supplies the input/output
// tile distributions and instantiates the block-level softmax/topk building blocks.
struct TopkSoftmaxWarpPerRowPolicy
{
    // Distribution used to load the [RowsPerBlock, Experts] logits tile.
    // Dim 0 (rows): IssuesPerCol x WarpsPerBlock (warp id) x RowsPerWarp;
    // dim 1 (experts): IssuesPerRow x LanesPerRow (lane id) x VectorSize.
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeInputDistribution()
    {
        // TODO: Y dim must have one dim that is not reduced
        return make_static_tile_distribution(
            tile_distribution_encoding<
                sequence<1>,
                tuple<sequence<Problem::IssuesPerCol, Problem::WarpsPerBlock, Problem::RowsPerWarp>,
                      sequence<Problem::IssuesPerRow, Problem::LanesPerRow, Problem::VectorSize>>,
                tuple<sequence<1>, sequence<1, 2>>,
                tuple<sequence<1>, sequence<2, 1>>,
                sequence<1, 2, 2>,
                sequence<0, 0, 2>>{});
    }

    // Distribution for the one-column output/index tiles: the result is
    // replicated across the LanesPerRow lanes of a row (R dimension), and each
    // row stores a single element per top-k step.
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeOutputDistribution()
    {
        return make_static_tile_distribution(
            tile_distribution_encoding<
                sequence<Problem::LanesPerRow>, // repeat this one
                tuple<sequence<Problem::WarpsPerBlock, Problem::RowsPerWarp>,
                      sequence<1>>, // each row write out single element
                tuple<sequence<1>, sequence<1, 0>>,
                tuple<sequence<0>, sequence<1, 0>>,
                sequence<2>,
                sequence<0>>{});
    }

    // Block softmax operating on the weight type.
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetSoftmax()
    {
        using softmax_problem = BlockSoftmax2DProblem<typename Problem::WeightType>;
        return BlockSoftmax2D<softmax_problem>{};
    }

    // Streaming block top-k over the softmax weights; ColLanes is set to the
    // same LanesPerRow used as the replicate length above.
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetTopk()
    {
        using topk_problem = BlockTopkStream2DProblem<typename Problem::WeightType,
                                                      typename Problem::IndexType,
                                                      Problem::LanesPerRow>;
        // Note: replicate is LanesPerRow
        return BlockTopkStream2D<topk_problem>{};
    }
};
}
// namespace ck_tile
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment