gaoqiong / composable_kernel_ROCM

Commit 667047b9, authored Sep 06, 2024 by carlushuang

topk-softmax

Parent: 840cba8e
Showing 5 changed files with 276 additions and 7 deletions:
include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_problem.hpp (+44, -0)
test/CMakeLists.txt (+1, -0)
test/tile_reduce/tile_reduce.cpp (+7, -7)
test/topk_softmax/CMakeLists.txt (+3, -0)
test/topk_softmax/topk_softmax.cpp (+221, -0)
include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_problem.hpp (new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include <string>
#include <type_traits>

namespace ck_tile {

template <typename InputType_,
          typename WeightType_,
          typename IndexType_,
          index_t Experts_,
          index_t IssuesPerCol_  = 1, // issue along col, to make sure block_reduce() OK
          index_t BytesPerIssue_ = sizeof(InputType_),
          index_t BlockSize_     = 256>
struct TopkSoftmaxWarpPerRowProblem
{
    // TODO: this kernel only supports warp-per-row
    using InputType  = remove_cvref_t<InputType_>;
    using WeightType = remove_cvref_t<WeightType_>;
    using IndexType  = remove_cvref_t<IndexType_>;

    static constexpr index_t Experts       = Experts_;
    static constexpr index_t BytesPerIssue = BytesPerIssue_;
    static constexpr index_t BlockSize     = BlockSize_;
    static constexpr index_t WarpSize      = get_warp_size();

    static_assert(BytesPerIssue % sizeof(InputType) == 0);
    static constexpr index_t VectorSize = BytesPerIssue / sizeof(InputType);

    static_assert(Experts % VectorSize == 0);
    static constexpr index_t LanesPerRow = min(Experts / VectorSize, WarpSize);

    static_assert(WarpSize % LanesPerRow == 0);
    static constexpr index_t RowsPerWarp   = WarpSize / LanesPerRow;
    static constexpr index_t IssuesPerRow  = Experts / (LanesPerRow * VectorSize);
    static constexpr index_t IssuesPerCol  = IssuesPerCol_;
    static constexpr index_t WarpsPerBlock = BlockSize / WarpSize;
    static constexpr index_t RowsPerBlock  = RowsPerWarp * WarpsPerBlock;
};

} // namespace ck_tile
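To make the derived constants concrete, here is a hypothetical instantiation (illustrative only, not part of the commit): fp16 input, 8 experts, the defaults elsewhere, and get_warp_size() assumed to return 64 (a CDNA wavefront; the numbers differ on 32-lane hardware).

// Illustrative only: fp16 input, 8 experts, default IssuesPerCol/BytesPerIssue/BlockSize.
// Assumes a 64-lane wavefront.
using Problem = ck_tile::TopkSoftmaxWarpPerRowProblem<ck_tile::fp16_t,
                                                      float,
                                                      ck_tile::index_t,
                                                      /*Experts_=*/8>;
static_assert(Problem::VectorSize == 1);    // BytesPerIssue (2) / sizeof(fp16_t) (2)
static_assert(Problem::LanesPerRow == 8);   // min(8 / 1, 64): 8 lanes cover one token row
static_assert(Problem::RowsPerWarp == 8);   // 64 / 8: each warp handles 8 token rows
static_assert(Problem::IssuesPerRow == 1);  // 8 / (8 * 1)
static_assert(Problem::WarpsPerBlock == 4); // 256 / 64
static_assert(Problem::RowsPerBlock == 32); // 8 * 4 token rows per 256-thread block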
test/CMakeLists.txt

@@ -219,5 +219,6 @@ endif()
 add_subdirectory(position_embedding)
 add_subdirectory(scatter_gather)
 add_subdirectory(topk)
+add_subdirectory(topk_softmax)
 add_subdirectory(tile_reduce)
test/tile_reduce/tile_reduce.cpp

@@ -50,11 +50,11 @@ __global__ void reduce_row(DataType* p_src, DataType* p_dst)
     constexpr auto src_dist = make_static_tile_distribution(
         tile_distribution_encoding<
             sequence<1>,
-            tuple<sequence<row_repeat, num_warps, row_lanes>, sequence<col_lanes, vec>>,
+            tuple<sequence<row_repeat, num_warps, row_lanes>, sequence<1, col_lanes, vec>>,
             tuple<sequence<1>, sequence<1, 2>>,
-            tuple<sequence<1>, sequence<2, 0>>,
-            sequence<1, 2>,
-            sequence<0, 1>>{});
+            tuple<sequence<1>, sequence<2, 1>>,
+            sequence<1, 2, 2>,
+            sequence<0, 0, 2>>{});

     auto src_view = make_naive_tensor_view<address_space_enum::global>(p_src, ...

@@ -98,7 +98,7 @@ __global__ void reduce_row(DataType* p_src, DataType* p_dst)
         block_tile_reduce<DataType>(data, sequence<1>{}, f_max, -numeric<DataType>::infinity());
     // further reduce cross thread, Note Now the HLength of r is 1D
-    block_tile_reduce_sync(r, f_max, bool_constant<false>{});
+    block_tile_reduce_xor_sync(r, f_max);
     if(threadIdx.x % col_lanes == 0)
     {

@@ -205,7 +205,7 @@ __global__ void reduce_row_argmax(DataType* p_src, DataType* p_dst, int* p_idx)
     auto r = block_tile_reduce<kv>(kv_data, sequence<1>{}, f_arg_max, arg_max_init);
     // further reduce cross thread, Note Now the HLength of r is 1D
-    block_tile_reduce_sync(r, f_arg_max, bool_constant<false>{});
+    block_tile_reduce_xor_sync(r, f_arg_max);
     auto o = make_static_distributed_tensor<DataType>(dst_dist);
     auto i = make_static_distributed_tensor<int>(dst_dist);

@@ -368,7 +368,7 @@ int main()
 {
     bool r = true;
     r &= test_tile_reduce<32, 64, float>();
-    r &= test_tile_reduce<32, 16, float, 4>();
+    r &= test_tile_reduce<32, 8, float, 4>();
     r &= test_tile_reduce<32, 16, ck_tile::fp16_t, 4>();
     r &= test_tile_reduce_argmax<32, 16, float, 4>();
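The functional change in this file swaps block_tile_reduce_sync(r, f, bool_constant<false>{}) for block_tile_reduce_xor_sync(r, f), a butterfly (XOR-shuffle) cross-lane reduction after the in-thread reduce. In a butterfly reduction every lane finishes holding the reduced value, so the `threadIdx.x % col_lanes == 0` lanes can consume the result without a separate broadcast. A minimal HIP sketch of the underlying pattern, independent of ck_tile (the helper name and the fixed float/max choice are illustrative assumptions):

// Butterfly max-reduction across a power-of-two group of lanes.
// After log2(ColLanes) exchange steps, every lane in the group holds the maximum.
template <int ColLanes>
__device__ float butterfly_row_max(float v)
{
    static_assert((ColLanes & (ColLanes - 1)) == 0, "group size must be a power of two");
    for(int mask = ColLanes >> 1; mask > 0; mask >>= 1)
        v = fmaxf(v, __shfl_xor(v, mask)); // exchange with lane (lane_id ^ mask)
    return v;
}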
test/topk_softmax/CMakeLists.txt (new file, mode 100644)

add_test_executable(test_topk_softmax topk_softmax.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/../../example/ck_tile/05_moe/topk_softmax_api.cpp)
target_include_directories(test_topk_softmax PRIVATE
    ${CMAKE_CURRENT_SOURCE_DIR}/../../example/ck_tile/05_moe)
target_compile_options(test_topk_softmax PRIVATE -v --save-temps -Wno-gnu-line-marker)
test/topk_softmax/topk_softmax.cpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#include <vector>
#include <iostream>
#include <numeric>
#include <cassert>
#include <cstdlib>
#include <time.h>
#include <type_traits>
#include <unordered_set>

#include "ck_tile/core.hpp"
#include "ck_tile/ops/reduce.hpp"
#include "topk_softmax_api.hpp"

// #ifndef TEST_TOPK_SOFTMAX_VERBOSE
// #define TEST_TOPK_SOFTMAX_VERBOSE 0
// #endif
// #define BLOCK_SIZE 256

template <typename T>
void dump_host_tensor_2d(const ck_tile::HostTensor<T>& x)
{
    auto len = x.get_lengths();
    assert(len.size() == 2);
    std::cout << "[";
    for(size_t i = 0; i < len[0]; i++)
    {
        std::cout << "[";
        for(size_t j = 0; j < len[1]; j++)
        {
            if constexpr(std::is_same_v<T, ck_tile::fp16_t>)
            {
                auto v = ck_tile::type_convert<float>(x(std::vector<std::size_t>{i, j}));
                std::cout << v;
                if(j != len[1] - 1)
                    std::cout << ",";
            }
            else
            {
                std::cout << x(std::vector<std::size_t>{i, j}) << " ";
            }
        }
        std::cout << "]";
        if(i != len[0] - 1)
            std::cout << ",";
        else
            std::cout << "]";
        std::cout << std::endl;
    }
    std::cout << "--------------------" << std::endl;
}

// CPU reference
template <typename InputType, typename WeightType, typename IndexType = ck_tile::index_t>
auto reference_topk_softmax(const ck_tile::HostTensor<InputType>& x,
                            ck_tile::index_t k,
                            ck_tile::index_t dim = -1,
                            bool largest         = true,
                            bool sorted          = true)
{
    using namespace ck_tile;
    // dump_host_tensor_2d(x);
    auto y = reference_softmax<InputType, WeightType, WeightType>(x, dim);
    // dump_host_tensor_2d(y);
    auto [y_values, y_indices] = reference_topk(y, k, dim, largest, sorted);
    // dump_host_tensor_2d(y_values);
    // dump_host_tensor_2d(y_indices);
    return ck_tile::make_tuple(y_values, y_indices);
}
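As a plain-C++ sketch of what the reference path computes per token row (assuming reference_softmax and reference_topk implement the usual row-wise softmax followed by a sorted top-k; topk_softmax_row is a hypothetical name, not ck_tile API):

#include <algorithm>
#include <cmath>
#include <numeric>
#include <utility>
#include <vector>

// Softmax over one row of expert logits, then the k largest probabilities
// with their expert indices, sorted by descending probability.
std::pair<std::vector<float>, std::vector<int>>
topk_softmax_row(const std::vector<float>& row, int k)
{
    // numerically stable softmax: subtract the row max before exp()
    float m = *std::max_element(row.begin(), row.end());
    std::vector<float> p(row.size());
    float sum = 0.f;
    for(std::size_t e = 0; e < row.size(); ++e)
        sum += (p[e] = std::exp(row[e] - m));
    for(float& v : p)
        v /= sum;

    // pick the k largest probabilities and the experts they belong to
    std::vector<int> idx(row.size());
    std::iota(idx.begin(), idx.end(), 0);
    std::partial_sort(idx.begin(), idx.begin() + k, idx.end(),
                      [&](int a, int b) { return p[a] > p[b]; });

    std::vector<int> indices(idx.begin(), idx.begin() + k);
    std::vector<float> values(k);
    for(int j = 0; j < k; ++j)
        values[j] = p[indices[j]];
    return {values, indices};
}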
// different threshold for different dtype
template <typename DataType>
auto get_elimit(std::string /*init_method*/)
{
    double rtol = 1e-3;
    double atol = 1e-3;
    return ck_tile::make_tuple(rtol, atol);
}

template <>
auto get_elimit<ck_tile::bf16_t>(std::string /*init_method*/)
{
    double rtol = 1e-2;
    double atol = 1e-2;
    return ck_tile::make_tuple(rtol, atol);
}

template <>
auto get_elimit<ck_tile::fp8_t>(std::string init_method)
{
    if(init_method == "ui" || init_method == "ni")
    {
        unsigned max_rounding_point_distance = 0;
        double atol                          = 2e-3;
        return ck_tile::make_tuple(max_rounding_point_distance, atol);
    }
    else
    {
        unsigned max_rounding_point_distance = 1;
        double atol                          = 0.0625;
        return ck_tile::make_tuple(max_rounding_point_distance, atol);
    }
}
auto create_args(int argc, char* argv[])
{
    ck_tile::ArgParser arg_parser;
    arg_parser.insert("v", "1", "whether to do CPU validation or not")
        .insert("input_prec", "fp16", "input data type. fp8/fp16/fp32 (representing 8/16/32 bit data)")
        .insert("weight_prec", "fp32", "weight data type")
        .insert("t", "32", "number of input tokens")
        .insert("e", "8", "number of experts")
        .insert("k", "2", "topk")
        .insert("kname", "0", "set to 1 will print kernel name");

    bool result = arg_parser.parse(argc, argv);
    return std::make_tuple(result, arg_parser);
}
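Given the test_topk_softmax target defined in the new CMakeLists.txt, and assuming ck_tile's ArgParser accepts -name=value pairs (an assumption about the CLI format, not stated in the commit), a validated fp16 run would look something like:

./test_topk_softmax -t=128 -e=8 -k=2 -input_prec=fp16 -weight_prec=fp32 -v=1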
template <typename InputType, typename WeightType, typename IndexType = ck_tile::index_t>
bool test_topk_softmax(ck_tile::ArgParser args)
{
    int validate            = args.get_int("v");
    std::string input_prec  = args.get_str("input_prec");
    std::string weight_prec = args.get_str("weight_prec");
    int tokens              = args.get_int("t");
    int experts             = args.get_int("e");
    int topk                = args.get_int("k");
    // int kname = args.get_int("kname");
    // int warmup = args.get_int("warmup");
    // int repeat = args.get_int("repeat");

    std::srand(std::time(nullptr));

    // tokens already considered batch size
    ck_tile::HostTensor<InputType> x_host({tokens, experts});
    ck_tile::HostTensor<WeightType> value_host({tokens, topk});
    ck_tile::HostTensor<IndexType> index_host({tokens, topk});

    ck_tile::FillUniformDistribution<InputType>{-5.f, 5.f}(x_host);

    ck_tile::DeviceMem x_dev(x_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem value_dev(value_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem index_dev(index_host.get_element_space_size_in_bytes());

    x_dev.ToDevice(x_host.data());

    topk_softmax_trait trait = [&]() {
        topk_softmax_trait t_;
        t_.input_type  = input_prec;
        t_.weight_type = weight_prec;
        t_.experts     = experts;
        return t_;
    }();

    topk_softmax_kargs karg = [&]() {
        topk_softmax_kargs a_;
        a_.p_input     = x_dev.GetDeviceBuffer();
        a_.p_output    = value_dev.GetDeviceBuffer();
        a_.p_indices   = index_dev.GetDeviceBuffer();
        a_.num_rows    = tokens;
        a_.num_experts = experts;
        a_.topk        = topk;
        return a_;
    }();

    ck_tile::stream_config sc{nullptr};
    topk_softmax(trait, karg, sc);

    value_dev.FromDevice(value_host.data());
    index_dev.FromDevice(index_host.data());

    bool rtn = true;
    if(validate)
    {
        ck_tile::HostTensor<WeightType> value_host_ref({tokens, topk});
        ck_tile::HostTensor<IndexType> index_host_ref({tokens, topk});

        auto [value_ref, index_ref] =
            reference_topk_softmax<InputType, WeightType, IndexType>(x_host, topk);

        auto [rtol, atol] = get_elimit<InputType>("");
        rtn &= ck_tile::check_err(
            value_host, value_ref, std::string("Value Error: Incorrect results!"), rtol, atol);
        rtn &= ck_tile::check_err(
            index_host, index_ref, std::string("Index Error: Incorrect results!"), rtol, atol);
    }
    return rtn;
}
int main(int argc, char** argv)
{
    auto [result, args] = create_args(argc, argv);
    if(!result)
        return -1;

    std::string input_prec  = args.get_str("input_prec");
    std::string weight_prec = args.get_str("weight_prec");

    bool r = true;
    if(input_prec.compare("fp16") == 0 && weight_prec.compare("fp32") == 0)
    {
        r &= test_topk_softmax<ck_tile::fp16_t, float, ck_tile::index_t>(args);
    }
    else if(input_prec.compare("bf16") == 0 && weight_prec.compare("fp32") == 0)
    {
        r &= test_topk_softmax<ck_tile::bf16_t, float, ck_tile::index_t>(args);
    }
    return r ? 0 : -1;
}