gaoqiong / composable_kernel_ROCM · Commits

Commit 667047b9, authored Sep 06, 2024 by carlushuang

topk-softmax

parent 840cba8e
Changes: 25. Showing 20 changed files with 896 additions and 125 deletions (+896, -125).
example/ck_tile/05_moe/fused_moe_interface.cpp                                      +0    -0
example/ck_tile/05_moe/fused_moe_interface.hpp                                      +22   -0
example/ck_tile/05_moe/topk_softmax_api.cpp                                         +50   -0
example/ck_tile/05_moe/topk_softmax_api.hpp                                         +21   -0
include/ck_tile/core/arch/utility.hpp                                               +43   -0
include/ck_tile/core/numeric/math.hpp                                               +2    -1
include/ck_tile/host/reference/reference_softmax.hpp                                +58   -20
include/ck_tile/host/reference/reference_topk.hpp                                   +21   -0
include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp                    +105  -103
include/ck_tile/ops/reduce/block/block_reduce.hpp                                   +89   -1
include/ck_tile/ops/softmax.hpp                                                     +8    -0
include/ck_tile/ops/softmax/block/block_softmax_2d.hpp                              +80   -0
include/ck_tile/ops/softmax/block/block_softmax_2d_problem.hpp                      +16   -0
include/ck_tile/ops/topk.hpp                                                        +8    -0
include/ck_tile/ops/topk/block/block_topk_stream_2d.hpp                             +112  -0
include/ck_tile/ops/topk/block/block_topk_stream_2d_problem.hpp                     +22   -0
include/ck_tile/ops/topk_softmax.hpp                                                +10   -0
include/ck_tile/ops/topk_softmax/kernel/topk_softmax_kernel.hpp                     +117  -0
include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_pipeline.hpp    +53   -0
include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_policy.hpp      +59   -0
example/ck_tile/05_moe/fused_moe_interface.cpp (new file, 0 → 100644, added empty)
example/ck_tile/05_moe/fused_moe_interface.hpp (new file, 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"

struct fused_moe_traits
{
    int hdim_q;
    int hdim_v;
    std::string data_type;
    bool is_group_mode;
    bool is_v_rowmajor;
    mask_enum mask_type;
    bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
    bool has_lse;
    bool has_dropout;
    bool do_fp8_static_quant;
    // TODO: padding check is inside this api
};

float fused_moe(fused_moe_traits, fused_moe_args, const ck_tile::stream_config&);
example/ck_tile/05_moe/topk_softmax_api.cpp (new file, 0 → 100644)

#include "topk_softmax_api.hpp"

float topk_softmax(topk_softmax_trait t, topk_softmax_kargs a, ck_tile::stream_config s)
{
    if(t.input_type == "fp16" && t.weight_type == "fp32")
    {
        using ts_input_type  = ck_tile::fp16_t;
        using ts_weight_type = float;
        using ts_index_type  = ck_tile::index_t;
        constexpr ck_tile::index_t ts_experts = 8;

        using ts_problem = ck_tile::
            TopkSoftmaxWarpPerRowProblem<ts_input_type, ts_weight_type, ts_index_type, ts_experts>;
        using ts_pipeline = ck_tile::TopkSoftmaxWarpPerRowPipeline<ts_problem>;
        using kernel      = ck_tile::TopkSoftmaxKernel<ts_pipeline>;

        auto kargs = kernel::MakeKargs(a);

        const dim3 grids      = kernel::GridSize(a);
        constexpr dim3 blocks = kernel::BlockSize();

        float ave_time = ck_tile::launch_kernel(
            s, ck_tile::make_kernel<blocks.x, 1>(kernel{}, grids, blocks, 0, kargs));
        return ave_time;
    }
    else if(t.input_type == "bf16" && t.weight_type == "fp32")
    {
        using ts_input_type  = ck_tile::bf16_t;
        using ts_weight_type = float;
        using ts_index_type  = ck_tile::index_t;
        constexpr ck_tile::index_t ts_experts = 8;

        using ts_problem = ck_tile::
            TopkSoftmaxWarpPerRowProblem<ts_input_type, ts_weight_type, ts_index_type, ts_experts>;
        using ts_pipeline = ck_tile::TopkSoftmaxWarpPerRowPipeline<ts_problem>;
        using kernel      = ck_tile::TopkSoftmaxKernel<ts_pipeline>;

        auto kargs = kernel::MakeKargs(a);

        const dim3 grids      = kernel::GridSize(a);
        constexpr dim3 blocks = kernel::BlockSize();

        float ave_time = ck_tile::launch_kernel(
            s, ck_tile::make_kernel<blocks.x, 1>(kernel{}, grids, blocks, 0, kargs));
        return ave_time;
    }
    return -1;
}
example/ck_tile/05_moe/topk_softmax_api.hpp (new file, 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/ops/topk_softmax.hpp"
#include <string>

struct topk_softmax_trait
{
    std::string input_type;
    std::string weight_type; // currently always float
    int experts;
};

struct topk_softmax_kargs : public ck_tile::TopkSoftmaxHostArgs
{
};

float topk_softmax(topk_softmax_trait t, topk_softmax_kargs a, ck_tile::stream_config s);
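For orientation, a minimal caller sketch follows. It is hypothetical and not part of this commit: the buffer names, row count, and use of a default-constructed ck_tile::stream_config are assumptions, and num_experts is fixed at 8 because that is the only expert count the dispatcher above instantiates. Logits are assumed row-major [num_rows, num_experts].

// Hypothetical caller sketch (not part of this commit).
#include <hip/hip_runtime.h>
#include "topk_softmax_api.hpp"

int main()
{
    int num_rows = 128, num_experts = 8, topk = 2;

    void *p_in, *p_out, *p_idx;
    hipMalloc(&p_in,  num_rows * num_experts * sizeof(ck_tile::fp16_t)); // [num_rows, num_experts] logits
    hipMalloc(&p_out, num_rows * topk * sizeof(float));                  // [num_rows, topk] softmax weights
    hipMalloc(&p_idx, num_rows * topk * sizeof(ck_tile::index_t));       // [num_rows, topk] expert ids

    topk_softmax_trait t{"fp16", "fp32", num_experts};
    topk_softmax_kargs a;
    a.p_input   = p_in;
    a.p_output  = p_out;
    a.p_indices = p_idx;
    a.num_rows  = num_rows;
    a.num_experts = num_experts;
    a.topk      = topk;

    float ave_time = topk_softmax(t, a, ck_tile::stream_config{});
    return ave_time < 0 ? 1 : 0; // -1 signals an unsupported type combination
}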
include/ck_tile/core/arch/utility.hpp

@@ -120,4 +120,47 @@ CK_TILE_DEVICE T warp_shuffle_down(const T& v_local, uint32_t lane_delta)
 #endif
 }
 
+template <typename T>
+CK_TILE_DEVICE T warp_shuffle(const T& v_local, uint32_t src_lane)
+{
+#if 0
+    return __shfl(v_local, src_lane);
+#elif 1
+    if constexpr(sizeof(int32_t) > sizeof(T))
+    {
+        union packet
+        {
+            int32_t x;
+            T v;
+        };
+        packet p;
+        p.v = v_local;
+        packet p_remote;
+        p_remote.x = __builtin_amdgcn_ds_bpermute(src_lane << 2, bit_cast<int32_t>(p));
+        return p_remote.v;
+    }
+    else if constexpr(sizeof(int32_t) == sizeof(T))
+    {
+        const int32_t v_remote_tmp =
+            __builtin_amdgcn_ds_bpermute(src_lane << 2, bit_cast<int32_t>(v_local));
+        return bit_cast<T>(v_remote_tmp);
+    }
+    else
+    {
+        static_assert(sizeof(T) % sizeof(int32_t) == 0, "wrong!");
+        constexpr index_t elm = sizeof(T) / sizeof(int32_t);
+        using vector_type     = thread_buffer<int32_t, elm>;
+        auto vs               = bit_cast<vector_type>(v_local);
+        auto vs_remote        = vector_type{};
+        static_for<0, elm, 1>{}([&](auto i_e) {
+            int32_t tmp =
+                __builtin_amdgcn_ds_bpermute(src_lane << 2, bit_cast<int32_t>(vs[i_e]));
+            vs_remote(i_e) = tmp;
+        });
+        return bit_cast<T>(vs_remote);
+    }
+#endif
+}
+
 } // namespace ck_tile
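The new warp_shuffle pulls a value from an arbitrary source lane via ds_bpermute (which takes a byte address, hence src_lane << 2) and dispatches on sizeof(T): sub-dword types are padded into one dword through a union, dword types go through directly, and larger types are split into dword packets. A small host-side illustration of that dispatch follows; it is plain C++, not device code, and only mirrors the size-class logic (the device version exchanges each 4-byte packet with __builtin_amdgcn_ds_bpermute instead).

// Host-side illustration (not ck_tile code) of warp_shuffle's size dispatch.
#include <cstdint>
#include <cstdio>

template <typename T>
int packets_needed()
{
    if constexpr(sizeof(int32_t) > sizeof(T))
        return 1; // sub-dword: padded into a single dword via a union
    else if constexpr(sizeof(int32_t) == sizeof(T))
        return 1; // exactly one dword: bit_cast directly
    else
    {
        static_assert(sizeof(T) % sizeof(int32_t) == 0, "wrong!");
        return sizeof(T) / sizeof(int32_t); // split into `elm` dword packets
    }
}

int main()
{
    printf("uint16_t -> %d packet(s)\n", packets_needed<uint16_t>()); // 1 (padded)
    printf("float    -> %d packet(s)\n", packets_needed<float>());    // 1
    printf("double   -> %d packet(s)\n", packets_needed<double>());   // 2
    return 0;
}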
include/ck_tile/core/numeric/math.hpp

@@ -1406,7 +1406,8 @@ CK_TILE_DEVICE T rcp(T x)
 #if !CK_TILE_WORKAROUND_SWDEV_383542
     return __frcp_rn(x);
 #else
-    return __ocml_native_recip_f32(x);
+    // return __ocml_native_recip_f32(x);
+    return __builtin_amdgcn_rcpf(x);
 #endif
 };
include/ck_tile/host/reference/reference_softmax.hpp

@@ -9,43 +9,81 @@
 namespace ck_tile {
 
-template <typename ADataType, typename AccDataType, typename BDataType>
-CK_TILE_HOST void reference_softmax(const HostTensor<ADataType>& a_m_n,
-                                    HostTensor<BDataType>& b_m_n)
+template <typename InputType, typename ComputeType, typename OutputType = ComputeType>
+CK_TILE_HOST void
+reference_softmax(const HostTensor<InputType>& x, HostTensor<OutputType>& y, index_t dim = -1)
 {
-    auto f = [&](auto m) {
-        const int N = a_m_n.mDesc.get_lengths()[1];
+    index_t rank = x.get_num_of_dimension();
+    assert(rank == y.get_num_of_dimension());
+    assert(dim == -1 || dim < rank);
+    index_t target_dim  = dim == -1 ? (rank - 1) : dim;
+    index_t softmax_len = x.get_length(target_dim);
+    index_t n_parallel  = x.get_element_size() / softmax_len;
+    auto x_len          = x.get_lengths();
 
-        AccDataType v_max = ck_tile::numeric<ADataType>::Lowest();
+    auto f = [&](auto i_element) {
+        std::vector<size_t> coord = [&]() {
+            std::vector<size_t> t_(rank, 0);
+            size_t r = i_element;
+            for(index_t i = rank - 1; i >= 0; i--)
+            {
+                if(i == target_dim)
+                    continue;
+                t_[i] = r % x_len[i];
+                r     = r / x_len[i];
+            }
+            return t_;
+        }();
 
-        // max
-        for(int n = 0; n < N; ++n)
+        ComputeType v_max = -ck_tile::numeric<ComputeType>::infinity();
+
+        // compute max
+        for(auto idx = 0; idx < softmax_len; idx++)
         {
-            const ADataType v_a = a_m_n(m, n);
-            v_max               = v_max < v_a ? v_a : v_max;
+            auto c_               = coord;
+            c_[target_dim]        = idx;
+            const ComputeType v_x = ck_tile::type_convert<ComputeType>(x(c_));
+            v_max                 = v_max < v_x ? v_x : v_max;
         }
 
-        AccDataType v_exp_sum = 0;
+        ComputeType v_exp_sum = static_cast<ComputeType>(0);
 
         // sum
-        for(int n = 0; n < N; ++n)
+        for(auto idx = 0; idx < softmax_len; idx++)
         {
-            const ADataType v_a = a_m_n(m, n);
-            v_exp_sum += ck_tile::exp(v_a - v_max);
+            auto c_               = coord;
+            c_[target_dim]        = idx;
+            const ComputeType v_x = ck_tile::type_convert<ComputeType>(x(c_));
+            v_exp_sum += ck_tile::exp(v_x - v_max);
         }
 
         // elementwise
-        for(int n = 0; n < N; ++n)
+        for(auto idx = 0; idx < softmax_len; idx++)
         {
-            const ADataType v_a = a_m_n(m, n);
-            b_m_n(m, n)         = ck_tile::exp(v_a - v_max) / v_exp_sum;
+            auto c_               = coord;
+            c_[target_dim]        = idx;
+            const ComputeType v_x = ck_tile::type_convert<ComputeType>(x(c_));
+            auto out              = ck_tile::exp(v_x - v_max) / v_exp_sum;
+            y(c_)                 = ck_tile::type_convert<OutputType>(out);
         }
     };
 
-    make_ParallelTensorFunctor(f,
-                               b_m_n.mDesc.get_lengths()[0])(std::thread::hardware_concurrency());
+    make_ParallelTensorFunctor(f, n_parallel)(std::thread::hardware_concurrency());
+}
+
+template <typename InputType, typename ComputeType, typename OutputType = ComputeType>
+CK_TILE_HOST auto reference_softmax(const HostTensor<InputType>& x, index_t dim = -1)
+{
+    HostTensor<OutputType> y(x.get_lengths(), x.get_strides());
+    reference_softmax<InputType, ComputeType, OutputType>(x, y, dim);
+    return y;
+}
 
 } // namespace ck_tile
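The trickiest part of the rewritten host reference is the index decomposition: each parallel work item i_element is unflattened into a coordinate that skips the softmax dimension, which is then swept separately for the max/sum/normalize passes. A standalone sketch of just that mapping, using an assumed example shape [2, 3, 4] with softmax along dim 1:

// Standalone illustration (not ck_tile code) of the i_element -> coordinate mapping.
#include <cstdio>
#include <vector>

int main()
{
    std::vector<std::size_t> x_len = {2, 3, 4}; // example shape
    int target_dim = 1;                         // softmax along dim 1 (length 3)
    std::size_t n_parallel = 2 * 4;             // element count / softmax length

    for(std::size_t i_element = 0; i_element < n_parallel; ++i_element)
    {
        std::vector<std::size_t> t_(x_len.size(), 0);
        std::size_t r = i_element;
        for(int i = static_cast<int>(x_len.size()) - 1; i >= 0; i--)
        {
            if(i == target_dim)
                continue; // the softmax dim is filled in by the inner sweep
            t_[i] = r % x_len[i];
            r     = r / x_len[i];
        }
        printf("i_element %zu -> (%zu, *, %zu)\n", i_element, t_[0], t_[2]);
    }
    return 0;
}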
include/ck_tile/host/reference/reference_topk.hpp

@@ -100,4 +100,25 @@ CK_TILE_HOST void reference_topk(const HostTensor<DataType>& x,
     make_ParallelTensorFunctor(f, n_parallel)(std::thread::hardware_concurrency());
 }
 
+// TODO: if using this method, the return tensor would be dense(no stride)
+template <typename DataType, typename IndexType = index_t>
+CK_TILE_HOST auto reference_topk(const HostTensor<DataType>& x,
+                                 index_t k,
+                                 index_t dim  = -1,
+                                 bool largest = true,
+                                 bool sorted  = true)
+{
+    auto lens          = x.get_lengths();
+    index_t target_dim = (dim == -1) ? (lens.size() - 1) : dim;
+    assert(target_dim < lens.size());
+    assert(k <= lens[target_dim]);
+    lens[target_dim] = k;
+
+    HostTensor<DataType> y_values(lens);
+    HostTensor<IndexType> y_indices(lens);
+
+    reference_topk<DataType, IndexType>(x, y_values, y_indices, k, dim, largest, sorted);
+    return ck_tile::make_tuple(y_values, y_indices);
+}
+
 } // namespace ck_tile
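The new convenience overload shrinks the chosen dimension to k and returns a (values, indices) pair. A standalone sketch of the per-row semantics it implements, in plain C++ independent of HostTensor, assuming the defaults largest=true and sorted=true:

// Standalone sketch (not ck_tile code) of per-row topk semantics:
// the k largest values plus their original column indices, in sorted order.
#include <algorithm>
#include <cstdio>
#include <numeric>
#include <vector>

int main()
{
    std::vector<float> row = {0.1f, 0.7f, 0.05f, 0.15f}; // softmax weights for one token
    int k = 2;

    std::vector<int> idx(row.size());
    std::iota(idx.begin(), idx.end(), 0);
    std::stable_sort(idx.begin(), idx.end(),
                     [&](int a, int b) { return row[a] > row[b]; }); // largest=true, sorted=true

    for(int i = 0; i < k; ++i)
        printf("value %.2f index %d\n", row[idx[i]], idx[i]); // 0.70@1, 0.15@3
    return 0;
}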
include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp (+105, -103)

This diff is collapsed.
include/ck_tile/ops/reduce/block/block_reduce.hpp

@@ -7,6 +7,10 @@
 namespace ck_tile {
 
+/*
+ * TODO: block_tile_reduce_sync() currently has a limitation
+ * Y dim must have at least one dim not been reduced
+ */
 // synchronize reduce result (cross lane reduction and broadcast on replicated dimension)
 template <typename AccDistributedTensor_, typename ReduceFunc, bool WithBroadcast = true>
 CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor,

@@ -55,7 +59,17 @@ CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor,
             // pull data from remote lane
             const auto v_remote = warp_shuffle_down(v_local, lid_delta);
+#if 0
+            if constexpr(Verbose_)
+            {
+                printf("warp_shuffle_down : %d - %d, %d (%.3f, %.3f)\n",
+                       static_cast<int>(threadIdx.x),
+                       static_cast<int>(lid_over_rid_derivative),
+                       static_cast<int>(lid_delta),
+                       v_local,
+                       v_remote);
+            }
+#endif
             // reduce
             v_local = reduce_func(v_local, v_remote);
         });

@@ -104,6 +118,76 @@ CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor,
     });
 }
 
+/*
+ * this version is faster, using xor to do reduce, no need broadcast anymore
+ * TODO: the limitation is to-be-reduced P dim can only mapping to one R dim?
+ */
+template <typename AccDistributedTensor_, typename ReduceFunc>
+CK_TILE_DEVICE void block_tile_reduce_xor_sync(AccDistributedTensor_& acc_tensor,
+                                               const ReduceFunc& reduce_func)
+{
+    using Dstr             = typename AccDistributedTensor_::StaticTileDistribution;
+    using DstrEncode       = typename Dstr::DstrEncode;
+    using DstrEncodeDetail = typename DstrEncode::detail;
+
+    constexpr index_t NDimP = Dstr::get_num_of_dimension_p();
+    constexpr index_t NDimR = Dstr::get_num_of_dimension_r();
+
+    constexpr index_t idim_p_lane = NDimP - 1;
+
+    constexpr index_t thread_buf_size = AccDistributedTensor_::get_thread_buffer_size();
+
+    // loop over thread data
+    static_for<0, thread_buf_size, 1>{}([&](auto i) {
+        auto v_local = acc_tensor.get_thread_buffer()[i];
+
+        // cross-lane reduce for replication
+        // only reduce on R dimension correspond to lane
+        // (lane id maps to this R dimension)
+        static_for<0, NDimR, 1>{}([&](auto idim_r) {
+            // FIXME: nasty to use does_p_own_r_
+            if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r])
+            {
+                constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];
+
+                constexpr index_t lid_over_rid_derivative =
+                    DstrEncodeDetail::ps_over_rs_derivative_[idim_p_lane][idim_r];
+
+                static_assert(is_power_of_two_integer(r_length),
+                              "wrong! only support power of 2 reduction");
+
+                constexpr index_t nstage = integer_log2_floor(r_length);
+
+                // reduction sweep forward
+                static_for<0, nstage, 1>{}([&](auto istage) {
+                    // TODO: lid_over_rid_derivative not ok in xor? maybe need limit the usage of
+                    // xor
+                    index_t src_lane = (__lane_id() * lid_over_rid_derivative) ^
+                                       (number<1 << istage.value>{}.value);
+
+                    // pull data from remote lane
+                    const auto v_remote = warp_shuffle(v_local, src_lane);
+#if 0
+                    if constexpr(Verbose_)
+                    {
+                        printf("block_tile_reduce_xor_sync : %d - %d, %d (%.3f, %.3f)\n",
+                               static_cast<int>(threadIdx.x),
+                               static_cast<int>(istage),
+                               static_cast<int>(src_lane),
+                               v_local,
+                               v_remote);
+                    }
+#endif
+                    // reduce
+                    v_local = reduce_func(v_local, v_remote);
+                });
+            }
+        });
+
+        acc_tensor.get_thread_buffer()(i) = v_local;
+    });
+}
+
 // FIXME: this is for 2D to 1D reduce only, need to support n-D
 template <typename AccDistributedTensor_,
           typename InDistributedTensor_,

@@ -175,6 +259,10 @@ CK_TILE_DEVICE void block_tile_reduce(AccDistributedTensor_& acc_tensor,
 #endif
 }
 
+/*
+ * TODO: block_tile_reduce() currently has a limitation
+ * Y dim must have at least one dim not been reduced
+ */
 template <typename AccDataType_,
           typename InDistributedTensor_,
           index_t... InReduceDims,
 ...
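The reason block_tile_reduce_xor_sync needs no broadcast step is the xor butterfly: at stage s each lane exchanges with lane ^ (1 << s), so after log2(N) stages every lane already holds the full reduction. A host-side illustration (plain C++, not ck_tile code; 8 lanes and a sum reduction are assumed stand-ins for the power-of-two replication length and reduce_func):

// Host-side illustration of the xor butterfly used by block_tile_reduce_xor_sync.
#include <cstdio>

int main()
{
    const int N = 8; // stand-in for a power-of-two replication length
    float lane[N] = {3, 1, 4, 1, 5, 9, 2, 6};

    for(int stage = 0; (1 << stage) < N; ++stage)
    {
        float remote[N];
        for(int lid = 0; lid < N; ++lid)
            remote[lid] = lane[lid ^ (1 << stage)]; // "warp_shuffle" from src_lane = lid ^ (1 << stage)
        for(int lid = 0; lid < N; ++lid)
            lane[lid] += remote[lid];               // reduce_func
    }

    for(int lid = 0; lid < N; ++lid)
        printf("lane %d: %g\n", lid, lane[lid]);    // every lane prints 31: no broadcast needed
    return 0;
}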
include/ck_tile/ops/softmax.hpp (new file, 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/ops/softmax/block/block_softmax_2d.hpp"
#include "ck_tile/ops/softmax/block/block_softmax_2d_problem.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/softmax/block/block_softmax_2d.hpp (new file, 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/reduce.hpp"

namespace ck_tile {

/*
simple 2d softmax implementation, along row (dim=1)
requirement:
    1). each row is within a warp
    2). data type must be a dword
*/
template <typename Problem_, typename Policy_ = void>
struct BlockSoftmax2D
{
    using Problem  = remove_cvref_t<Problem_>;
    using Policy   = remove_cvref_t<Policy_>;
    using DataType = typename Problem::DataType;

    template <typename DistributedTensor, index_t dim = 1>
    CK_TILE_DEVICE void
    operator()(const DistributedTensor& x, DistributedTensor& y, number<dim> = {})
    {
        const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
        const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };

        // compute row max
        auto row_max =
            block_tile_reduce<DataType>(x, sequence<dim>{}, f_max, -numeric<DataType>::infinity());
        block_tile_reduce_xor_sync(row_max, f_max);

        // compute elementwise softmax
        constexpr auto span_2d = DistributedTensor::get_distributed_spans();

        sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
            constexpr auto i_idx = make_tuple(idx0);
            sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
                constexpr auto i_j_idx = make_tuple(idx0, idx1);
                y(i_j_idx)             = exp(x[i_j_idx] - row_max(i_idx));
            });
        });

        // compute row sum
        auto row_sum = block_tile_reduce<DataType>(y, sequence<dim>{}, f_sum, DataType{0});
        block_tile_reduce_xor_sync(row_sum, f_sum);

        // reciprocal
        auto r = make_static_distributed_tensor<DataType>(row_sum.get_tile_distribution());
        constexpr auto span_1d = decltype(r)::get_distributed_spans();
        sweep_tile_span(span_1d[number<0>{}], [&](auto idx0) {
            constexpr auto i_idx = make_tuple(idx0);
            r(i_idx)             = DataType{1} / row_sum(i_idx);
        });

        // scale
        sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
            constexpr auto i_idx = make_tuple(idx0);
            sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
                constexpr auto i_j_idx = make_tuple(idx0, idx1);
                y(i_j_idx)             = y(i_j_idx) * r(i_idx);
            });
        });
    }

    template <typename DistributedTensor, index_t dim = 1>
    CK_TILE_DEVICE decltype(auto) operator()(const DistributedTensor& x, number<dim> = {})
    {
        auto y = DistributedTensor{}; // distributed tensor
        operator()(x, y, number<dim>{});
        return y;
    }
};

} // namespace ck_tile
0 → 100644
View file @
667047b9
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
template
<
typename
DataType_
>
struct
BlockSoftmax2DProblem
{
using
DataType
=
remove_cvref_t
<
DataType_
>
;
};
}
// namespace ck_tile
include/ck_tile/ops/topk.hpp (new file, 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/ops/topk/block/block_topk_stream_2d.hpp"
#include "ck_tile/ops/topk/block/block_topk_stream_2d_problem.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/topk/block/block_topk_stream_2d.hpp (new file, 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"

namespace ck_tile {

/*
simple 2d topk implementation, along row (dim=1)
requirement:
    1). each row is within a warp
*/
template <typename Problem_, typename Policy_ = void>
struct BlockTopkStream2D
{
    using Problem   = remove_cvref_t<Problem_>;
    using Policy    = remove_cvref_t<Policy_>;
    using DataType  = typename Problem::DataType;
    using IndexType = typename Problem::IndexType;

    // TODO: if DataType is subdword, need pack into single dword to use argmax
    struct ArgmaxPacket
    {
        DataType arg;
        index_t value;
    };

    template <typename DistributedTensor, typename OutWindow, typename IdxWindow, index_t dim = 1>
    CK_TILE_DEVICE void operator()(const DistributedTensor& x,
                                   OutWindow& out_window,
                                   IdxWindow& idx_window,
                                   index_t k,
                                   number<dim> = {})
    {
        // static_assert(OutWindow::get_window_lengths()[number<1>] == 1);
        static_assert(
            std::is_same_v<typename DistributedTensor::DataType, typename OutWindow::DataType> &&
            std::is_same_v<typename DistributedTensor::DataType, DataType>);
        static_assert(std::is_same_v<typename IdxWindow::DataType, IndexType>);

        DistributedTensor x_tmp = x;
        constexpr auto dst_dist = typename IdxWindow::TileDstr{};

        // argmax for topk
        const auto f_argmax = [](ArgmaxPacket e0, ArgmaxPacket e1) {
            return e0.arg > e1.arg ? e0 : e1;
        };

        for(index_t i_k = 0; i_k < k; i_k++)
        {
            constexpr auto span_2d = DistributedTensor::get_distributed_spans();

            auto packet = [&]() {
                auto tmp = make_static_distributed_tensor<ArgmaxPacket>(x.get_tile_distribution());
                sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
                    sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
                        const auto tile_idx = get_x_indices_from_distributed_indices(
                            tmp.get_tile_distribution(), make_tuple(idx0, idx1));
                        constexpr auto i_j_idx = make_tuple(idx0, idx1);
                        ArgmaxPacket t;
                        t.arg        = x_tmp(i_j_idx); // !!! we reference x here
                        t.value      = tile_idx.at(number<1>{});
                        tmp(i_j_idx) = t;
                    });
                });
                return tmp;
            }();

            auto argmax_init = ArgmaxPacket{-numeric<DataType>::infinity(), 0};

            auto r = block_tile_reduce<ArgmaxPacket>(packet, sequence<1>{}, f_argmax, argmax_init);
            block_tile_reduce_xor_sync(r, f_argmax);

            auto o = make_static_distributed_tensor<DataType>(dst_dist);
            auto i = make_static_distributed_tensor<IndexType>(dst_dist);

            sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
                sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
                    constexpr auto i_j_idx = make_tuple(idx0, idx1);
                    ArgmaxPacket tmp       = r(i_j_idx);
                    o(i_j_idx)             = tmp.arg;
                    i(i_j_idx)             = tmp.value;
                });
            });

            // update value
            sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
                sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
                    const auto tile_idx = get_x_indices_from_distributed_indices(
                        x.get_tile_distribution(), make_tuple(idx0, idx1));
                    auto col_id            = tile_idx.at(number<1>{});
                    constexpr auto i_j_idx = make_tuple(idx0, idx1);
                    x_tmp(i_j_idx) = (col_id == r(i_j_idx).value) ? -numeric<DataType>::infinity()
                                                                  : x_tmp(i_j_idx);
                });
            });

            if(threadIdx.x % Problem::ColLanes == 0)
            {
                store_tile(out_window, o);
                store_tile(idx_window, i);
            }
            move_tile_window(out_window, {number<0>{}, number<1>{}});
            move_tile_window(idx_window, {number<0>{}, number<1>{}});
        }
    }
};

} // namespace ck_tile
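The "stream" in BlockTopkStream2D is this loop structure: k rounds, each doing an argmax reduction across the row, writing one (value, index) pair out, then masking the winner with -infinity so the next round finds the runner-up. No sort is ever performed. A host-side sketch of the same scheme on one row (plain C++, made-up weights):

// Host-side sketch of the streaming topk scheme: k rounds of argmax + mask.
#include <cstdio>
#include <limits>

int main()
{
    float x[8] = {0.05f, 0.30f, 0.10f, 0.25f, 0.02f, 0.20f, 0.03f, 0.05f};
    const int k = 3;

    for(int i_k = 0; i_k < k; ++i_k)
    {
        int arg = 0;
        for(int j = 1; j < 8; ++j)
            if(x[j] > x[arg])
                arg = j;                                   // the f_argmax reduction
        printf("round %d: value %.2f index %d\n", i_k, x[arg], arg);
        x[arg] = -std::numeric_limits<float>::infinity();  // knock out the winner
    }
    return 0;
}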
include/ck_tile/ops/topk/block/block_topk_stream_2d_problem.hpp (new file, 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"

namespace ck_tile {

/*
simple 2d topk implementation, along row (dim=1)
requirement:
    1). each row is within a warp
*/
template <typename DataType_, typename IndexType_, index_t ColLanes_>
struct BlockTopkStream2DProblem
{
    using DataType  = remove_cvref_t<DataType_>;
    using IndexType = remove_cvref_t<IndexType_>;
    static constexpr index_t ColLanes = ColLanes_;
};

} // namespace ck_tile
include/ck_tile/ops/topk_softmax.hpp (new file, 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/ops/topk_softmax/kernel/topk_softmax_kernel.hpp"
#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_pipeline.hpp"
#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_policy.hpp"
#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_problem.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/topk_softmax/kernel/topk_softmax_kernel.hpp (new file, 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/elementwise.hpp"
#include <string>
#include <type_traits>

namespace ck_tile {

struct TopkSoftmaxHostArgs
{
    const void* p_input;
    void* p_output;
    void* p_indices;
    index_t num_rows;
    index_t num_experts;
    index_t topk;
};

template <typename Pipeline_>
struct TopkSoftmaxKernel
{
    using Pipeline   = remove_cvref_t<Pipeline_>;
    using Problem    = remove_cvref_t<typename Pipeline::Problem>;
    using InputType  = typename Problem::InputType;
    using WeightType = typename Problem::WeightType;
    using IndexType  = typename Problem::IndexType;

    struct TopkSoftmaxKargs
    {
        const void* p_input;
        void* p_output;
        void* p_indices;
        index_t num_rows;
        index_t num_experts;
        index_t topk;
    };

    using Kargs = TopkSoftmaxKargs;
    using Hargs = TopkSoftmaxHostArgs;

    CK_TILE_HOST static constexpr auto GridSize(const Hargs& h)
    {
        const int num_warps  = (h.num_rows + Problem::RowsPerWarp - 1) / Problem::RowsPerWarp;
        const int num_blocks = (num_warps + Problem::WarpsPerBlock - 1) / Problem::WarpsPerBlock;
        return dim3(num_blocks);
    }

    CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
    {
        Kargs k;
        k.p_input     = h.p_input;
        k.p_output    = h.p_output;
        k.p_indices   = h.p_indices;
        k.num_rows    = h.num_rows;
        k.num_experts = h.num_experts;
        k.topk        = h.topk;
        return k;
    }

    CK_TILE_HOST_DEVICE static constexpr auto BlockSize() { return Problem::BlockSize; }

    CK_TILE_DEVICE void operator()(Kargs kargs) const
    {
        index_t block_row_id = static_cast<index_t>(blockIdx.x * Problem::RowsPerBlock);

        const auto input_window = [&]() {
            const InputType* p_input = reinterpret_cast<const InputType*>(kargs.p_input) +
                                       blockIdx.x * Problem::RowsPerBlock * kargs.num_experts;

            auto tmp = make_naive_tensor_view_packed<address_space_enum::global>(
                p_input,
                make_tuple(kargs.num_rows, kargs.num_experts),
                number<Problem::VectorSize>{});

            auto view = pad_tensor_view(
                tmp,
                make_tuple(number<Problem::RowsPerBlock>{}, number<Problem::Experts>{}),
                sequence<1, 1>{});
            return make_tile_window(
                view,
                make_tuple(number<Problem::RowsPerBlock>{}, number<Problem::Experts>{}),
                {block_row_id, 0});
        }();

        auto output_window = [&]() {
            WeightType* p_output = reinterpret_cast<WeightType*>(kargs.p_output) +
                                   blockIdx.x * Problem::RowsPerBlock * kargs.topk;

            auto tmp = make_naive_tensor_view_packed<address_space_enum::global>(
                p_output, make_tuple(kargs.num_rows, kargs.topk), number<Problem::VectorSize>{});
            auto view = pad_tensor_view(
                tmp, make_tuple(number<Problem::RowsPerBlock>{}, number<1>{}), sequence<1, 0>{});
            return make_tile_window(
                view, make_tuple(number<Problem::RowsPerBlock>{}, number<1>{}), {block_row_id, 0});
        }();

        auto indices_window = [&]() {
            IndexType* p_indices = reinterpret_cast<IndexType*>(kargs.p_indices) +
                                   blockIdx.x * Problem::RowsPerBlock * kargs.topk;

            auto tmp = make_naive_tensor_view_packed<address_space_enum::global>(
                p_indices, make_tuple(kargs.num_rows, kargs.topk), number<Problem::VectorSize>{});
            auto view = pad_tensor_view(
                tmp, make_tuple(number<Problem::RowsPerBlock>{}, number<1>{}), sequence<1, 0>{});
            return make_tile_window(
                view, make_tuple(number<Problem::RowsPerBlock>{}, number<1>{}), {block_row_id, 0});
        }();

        Pipeline{}(input_window, output_window, indices_window, kargs.topk);
    }
};

} // namespace ck_tile
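GridSize() is two ceiling divisions: rows are first grouped into warps, then warps into blocks. A worked example with hypothetical tile parameters (RowsPerWarp = 4, WarpsPerBlock = 4; the real values come from the Problem type):

// Worked example of the GridSize() arithmetic with assumed tile parameters.
#include <cstdio>

int main()
{
    int num_rows = 1000, RowsPerWarp = 4, WarpsPerBlock = 4;
    int num_warps  = (num_rows + RowsPerWarp - 1) / RowsPerWarp;      // ceil(1000/4) = 250
    int num_blocks = (num_warps + WarpsPerBlock - 1) / WarpsPerBlock; // ceil(250/4)  = 63
    printf("%d warps -> %d blocks\n", num_warps, num_blocks);
    return 0;
}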
include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_pipeline.hpp (new file, 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_policy.hpp"
#include <string>
#include <type_traits>

namespace ck_tile {

template <typename Problem_, typename Policy_ = TopkSoftmaxWarpPerRowPolicy>
struct TopkSoftmaxWarpPerRowPipeline
{
    // TODO: this kernel only support warp per row
    using Problem = remove_cvref_t<Problem_>;
    using Policy  = remove_cvref_t<Policy_>;

    template <typename InputWindow, typename OutputWindow, typename IndexWindow>
    CK_TILE_DEVICE auto operator()(const InputWindow& input_window,
                                   OutputWindow& out_window,
                                   IndexWindow& idx_window,
                                   index_t k)
    {
        auto input_win = make_tile_window(input_window.get_bottom_tensor_view(),
                                          input_window.get_window_lengths(),
                                          input_window.get_window_origin(),
                                          Policy::template MakeInputDistribution<Problem>());
        auto x = load_tile(input_win);
        auto w = cast_tile<typename Problem::WeightType>(x);

        auto softmax = Policy::template GetSoftmax<Problem>();

        // softmax
        auto y = softmax(w);

        auto topk = Policy::template GetTopk<Problem>();

        auto out_win = make_tile_window(out_window.get_bottom_tensor_view(),
                                        out_window.get_window_lengths(),
                                        out_window.get_window_origin(),
                                        Policy::template MakeOutputDistribution<Problem>());
        auto idx_win = make_tile_window(idx_window.get_bottom_tensor_view(),
                                        idx_window.get_window_lengths(),
                                        idx_window.get_window_origin(),
                                        Policy::template MakeOutputDistribution<Problem>());

        topk(y, out_win, idx_win, k);
    }
};

} // namespace ck_tile
include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_policy.hpp (new file, 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/softmax.hpp"
#include "ck_tile/ops/topk.hpp"

namespace ck_tile {

struct TopkSoftmaxWarpPerRowPolicy
{
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeInputDistribution()
    {
        // TODO: Y dim must have one dim that is not reduced
        return make_static_tile_distribution(
            tile_distribution_encoding<
                sequence<1>,
                tuple<sequence<Problem::IssuesPerCol, Problem::WarpsPerBlock, Problem::RowsPerWarp>,
                      sequence<Problem::IssuesPerRow, Problem::LanesPerRow, Problem::VectorSize>>,
                tuple<sequence<1>, sequence<1, 2>>,
                tuple<sequence<1>, sequence<2, 1>>,
                sequence<1, 2, 2>,
                sequence<0, 0, 2>>{});
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeOutputDistribution()
    {
        return make_static_tile_distribution(
            tile_distribution_encoding<
                sequence<Problem::LanesPerRow>, // repeat this one
                tuple<sequence<Problem::WarpsPerBlock, Problem::RowsPerWarp>,
                      sequence<1>>, // each row write out single element
                tuple<sequence<1>, sequence<1, 0>>,
                tuple<sequence<0>, sequence<1, 0>>,
                sequence<2>,
                sequence<0>>{});
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetSoftmax()
    {
        using softmax_problem = BlockSoftmax2DProblem<typename Problem::WeightType>;
        return BlockSoftmax2D<softmax_problem>{};
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetTopk()
    {
        using topk_problem = BlockTopkStream2DProblem<typename Problem::WeightType,
                                                      typename Problem::IndexType,
                                                      Problem::LanesPerRow>;
        // Note: replicate is LanesPerRow
        return BlockTopkStream2D<topk_problem>{};
    }
};

} // namespace ck_tile