gaoqiong / composable_kernel_ROCM / Commits

Commit eed60199, authored Sep 13, 2024 by carlushuang
Parent: cae751d1

more robust api

This commit makes the topk_softmax API more robust: input and output row strides become explicit kernel arguments (rows may now live in padded buffers), a persistent launch mode is added behind a new LaunchType problem parameter (0 keeps the default one-shot "streaming" launch; >0 sizes the grid to the device and loops), the block top-k helper stops mutating its caller's tile windows, and the test gains stride options plus optional per-token verification.

Showing 7 changed files with 294 additions and 112 deletions.
  include/ck_tile/ops/topk/block/block_topk_stream_2d.hpp                           +8   -7
  include/ck_tile/ops/topk_softmax/kernel/topk_softmax_kernel.hpp                   +80  -31
  include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_pipeline.hpp  +69  -35
  include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_policy.hpp    +9   -5
  include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_problem.hpp   +7   -5
  test/topk_softmax/script/smoke_test.sh (new)                                      +22  -0
  test/topk_softmax/topk_softmax.cpp                                                +99  -29
include/ck_tile/ops/topk/block/block_topk_stream_2d.hpp

The block top-k functor now takes its output and index windows by const reference and advances local copies, so it no longer mutates the caller's windows as a side effect:

```diff
@@ -30,12 +30,13 @@ struct BlockTopkStream2D
     template <typename DistributedTensor,
               typename OutWindow,
               typename IdxWindow,
               index_t dim = 1>
     CK_TILE_DEVICE void operator()(const DistributedTensor& x,
-                                   OutWindow& out_window,
-                                   IdxWindow& idx_window,
+                                   const OutWindow& out_window,
+                                   const IdxWindow& idx_window,
                                    index_t k,
                                    number<dim> = {})
     {
-        // static_assert(OutWindow::get_window_lengths()[number<1>] == 1);
+        OutWindow out_window_tmp = out_window;
+        IdxWindow idx_window_tmp = idx_window;
         static_assert(
             std::is_same_v<typename DistributedTensor::DataType, typename OutWindow::DataType> &&
             std::is_same_v<typename DistributedTensor::DataType, DataType>);
@@ -100,11 +101,11 @@ struct BlockTopkStream2D
             if(threadIdx.x % Problem::ColLanes == 0)
             {
-                store_tile(out_window, o);
-                store_tile(idx_window, i);
+                store_tile(out_window_tmp, o);
+                store_tile(idx_window_tmp, i);
             }
-            move_tile_window(out_window, {number<0>{}, number<1>{}});
-            move_tile_window(idx_window, {number<0>{}, number<1>{}});
+            move_tile_window(out_window_tmp, {number<0>{}, number<1>{}});
+            move_tile_window(idx_window_tmp, {number<0>{}, number<1>{}});
```
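This is the "more robust" theme in miniature: tile windows are cheap value types, so a routine that needs to advance a window can take it by const reference and iterate over a private copy, leaving the caller's window where it was. A standalone C++ sketch of the pattern, with a hypothetical Window type standing in for the real ck_tile tile windows:

```cpp
#include <cstddef>

// Hypothetical stand-in for a tile window: a view plus a cursor.
struct Window
{
    float* base;
    std::size_t pos;
    void store(float v) { base[pos] = v; }
    void move(std::size_t step) { pos += step; }
};

// Robust API: advance a local copy, so repeated calls (e.g. from a
// persistent-kernel loop) always start from the caller's origin.
void write_k(const Window& out, const float* vals, std::size_t k)
{
    Window tmp = out; // the caller's cursor is never mutated
    for(std::size_t i = 0; i < k; ++i)
    {
        tmp.store(vals[i]);
        tmp.move(1);
    }
}
```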
include/ck_tile/ops/topk_softmax/kernel/topk_softmax_kernel.hpp

The host and kernel argument structs gain the two strides, and the header pulls in the HIP error-check helper used by the new GridSize:

```diff
@@ -6,6 +6,7 @@
 #include "ck_tile/core.hpp"
 #include "ck_tile/ops/common.hpp"
 #include "ck_tile/ops/elementwise.hpp"
+#include "ck_tile/host/hip_check_error.hpp"
 #include <string>
 #include <type_traits>
@@ -19,6 +20,8 @@ struct TopkSoftmaxHostArgs
     index_t num_rows;
     index_t num_experts;
     index_t topk;
+    index_t stride_input;  // row stride for input, at least num_experts
+    index_t stride_output; // row stride for output/indices, at least topk
 };

 template <typename Pipeline_>
@@ -39,6 +42,8 @@ struct TopkSoftmaxKernel
     index_t num_rows;
     index_t num_experts;
     index_t topk;
+    index_t stride_input;  // row stride for input, at least num_experts
+    index_t stride_output; // row stride for output/indices, at least topk
 };
 using Kargs = TopkSoftmaxKargs;
```
GridSize now honors the LaunchType knob: 0 keeps the one-shot grid sized to the problem, while a positive value sizes the grid to the device for a persistent sweep:

```diff
@@ -46,21 +51,37 @@ struct TopkSoftmaxKernel
     CK_TILE_HOST static constexpr auto GridSize(const Hargs& h)
     {
-        const int num_warps  = (h.num_rows + Problem::RowsPerWarp - 1) / Problem::RowsPerWarp;
-        const int num_blocks = (num_warps + Problem::WarpsPerBlock - 1) / Problem::WarpsPerBlock;
-        return dim3(num_blocks);
+        if constexpr(Problem::LaunchType > 0)
+        {
+            int num_cu = [&]() {
+                hipDeviceProp_t dev_prop;
+                hipDevice_t dev;
+                HIP_CHECK_ERROR(hipGetDevice(&dev));
+                HIP_CHECK_ERROR(hipGetDeviceProperties(&dev_prop, dev));
+                return dev_prop.multiProcessorCount;
+            }();
+            return dim3(num_cu * Problem::LaunchType);
+        }
+        else
+        {
+            const int num_warps  = (h.num_rows + Problem::RowsPerWarp - 1) / Problem::RowsPerWarp;
+            const int num_blocks = (num_warps + Problem::WarpsPerBlock - 1) / Problem::WarpsPerBlock;
+            return dim3(num_blocks);
+        }
     }
```
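A quick sanity check of the two sizing modes (standalone C++ with illustrative constants; a real build reads the CU count from hipGetDeviceProperties as the diff shows):

```cpp
#include <cstdio>

int main()
{
    // Illustrative problem constants, not the library's defaults.
    const int rows_per_warp   = 8;  // rows covered by one warp
    const int warps_per_block = 4;  // e.g. BlockSize 256 / WarpSize 64
    const int num_rows        = 1000;

    // LaunchType == 0: one-shot launch, grid sized to cover every row once.
    const int num_warps  = (num_rows + rows_per_warp - 1) / rows_per_warp;      // 125
    const int num_blocks = (num_warps + warps_per_block - 1) / warps_per_block; // 32
    std::printf("one-shot grid: %d blocks\n", num_blocks);

    // LaunchType > 0: persistent launch, grid sized to the GPU instead; the
    // kernel then loops, striding the whole grid across the rows.
    const int num_cu      = 80; // assumed CU count
    const int launch_type = 2;  // blocks per CU (target occupancy)
    std::printf("persistent grid: %d blocks\n", num_cu * launch_type);
    return 0;
}
```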
MakeKargs forwards the new strides:

```diff
@@ (continuation of the hunk above)
     CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
     {
         Kargs k;
         k.p_input     = h.p_input;
         k.p_output    = h.p_output;
         k.p_indices   = h.p_indices;
         k.num_rows    = h.num_rows;
         k.num_experts = h.num_experts;
         k.topk        = h.topk;
+        k.stride_input  = h.stride_input;
+        k.stride_output = h.stride_output;
         return k;
     }
```
The kernel entry now early-outs when the block starts past the last row (possible with a persistent grid), keeps the row offsets in scalar registers via __builtin_amdgcn_readfirstlane, and builds an explicitly strided input view that is padded only along the expert dimension:

```diff
@@ -68,19 +89,30 @@ struct TopkSoftmaxKernel
     CK_TILE_DEVICE void operator()(Kargs kargs) const
     {
         index_t block_row_id = static_cast<index_t>(blockIdx.x * Problem::RowsPerBlock);
+        if(block_row_id > kargs.num_rows)
+            return;
+
+        index_t block_os_inp = __builtin_amdgcn_readfirstlane(block_row_id * kargs.stride_input);
+        index_t block_os_out = __builtin_amdgcn_readfirstlane(block_row_id * kargs.stride_output);
+        index_t num_rows_rem = __builtin_amdgcn_readfirstlane(kargs.num_rows - block_row_id);
+
         const auto input_window = [&]() {
-            const InputType* p_input = reinterpret_cast<const InputType*>(kargs.p_input) +
-                                       block_row_id * kargs.num_experts;
-            auto tmp = make_naive_tensor_view_packed<address_space_enum::global>(
+            const InputType* p_input =
+                reinterpret_cast<const InputType*>(kargs.p_input) + block_os_inp;
+            auto tmp = make_naive_tensor_view<address_space_enum::global>(
                 p_input,
-                make_tuple(kargs.num_rows, kargs.num_experts),
-                number<Problem::VectorSize>{});
+                make_tuple(num_rows_rem, kargs.num_experts),
+                make_tuple(kargs.stride_input, 1),
+                number<Problem::VectorSize>{},
+                number<1>{});
             auto view = pad_tensor_view(
                 tmp,
                 make_tuple(number<Problem::RowsPerBlock>{}, number<Problem::Experts>{}),
-                sequence<1, 1>{});
+                sequence<0, 1>{}); // out-most dim no need pad (leverage oob)
             return make_tile_window(
                 view,
```
```diff
@@ -89,29 +121,46 @@ struct TopkSoftmaxKernel
         }();

         auto output_window = [&]() {
             WeightType* p_output =
-                reinterpret_cast<WeightType*>(kargs.p_output) + block_row_id * kargs.topk;
-            auto tmp = make_naive_tensor_view_packed<address_space_enum::global>(
+                reinterpret_cast<WeightType*>(kargs.p_output) + block_os_out;
+            auto tmp = make_naive_tensor_view<address_space_enum::global>(
                 p_output,
-                make_tuple(kargs.num_rows, kargs.topk),
-                number<Problem::VectorSize>{});
+                make_tuple(num_rows_rem, kargs.topk),
+                make_tuple(kargs.stride_output, 1),
+                number<Problem::VectorSize>{},
+                number<1>{});
             auto view = pad_tensor_view(
                 tmp,
                 make_tuple(number<Problem::RowsPerBlock>{}, number<1>{}),
-                sequence<1, 0>{});
+                sequence<0, 0>{});
+            // 1. out-most dim no need pad (leverage oob)
+            // 2. we loop over topk 1-1, no need padding
             return make_tile_window(
                 view, make_tuple(number<Problem::RowsPerBlock>{}, number<1>{}), {0, 0});
         }();

         auto indices_window = [&]() {
             IndexType* p_indices =
-                reinterpret_cast<IndexType*>(kargs.p_indices) + block_row_id * kargs.topk;
-            auto tmp = make_naive_tensor_view_packed<address_space_enum::global>(
+                reinterpret_cast<IndexType*>(kargs.p_indices) + block_os_out;
+            auto tmp = make_naive_tensor_view<address_space_enum::global>(
                 p_indices,
-                make_tuple(kargs.num_rows, kargs.topk),
-                number<Problem::VectorSize>{});
+                make_tuple(num_rows_rem, kargs.topk),
+                make_tuple(kargs.stride_output, 1),
+                number<Problem::VectorSize>{},
+                number<1>{});
             auto view = pad_tensor_view(
                 tmp,
                 make_tuple(number<Problem::RowsPerBlock>{}, number<1>{}),
-                sequence<1, 0>{});
+                sequence<0, 0>{});
+            // 1. out-most dim no need pad (leverage oob)
+            // 2. we loop over topk 1-1, no need padding
             return make_tile_window(
                 view, make_tuple(number<Problem::RowsPerBlock>{}, number<1>{}), {0, 0});
         }();

-        Pipeline{}(input_window, output_window, indices_window, kargs.topk, kargs.num_experts);
+        Pipeline{}(input_window,
+                   output_window,
+                   indices_window,
+                   kargs.num_rows,
+                   kargs.num_experts,
+                   kargs.topk,
+                   block_row_id);
     }
 };
 } // namespace ck_tile
```
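The switch from make_naive_tensor_view_packed to make_naive_tensor_view is what actually threads the stride through: a packed view assumes the row stride equals the row length, while the strided view keeps the two independent. A toy standalone analogue of the difference:

```cpp
#include <cstdio>
#include <vector>

// Toy 2D view: "packed" means row_stride == cols; a stride allows padding.
struct View2D
{
    const float* data;
    int rows, cols, row_stride;
    float at(int r, int c) const { return data[r * row_stride + c]; }
};

int main()
{
    // 2 rows x 3 cols stored with a row stride of 4 (one padding element).
    std::vector<float> buf = {1, 2, 3, -1, 4, 5, 6, -1};
    View2D packed{buf.data(), 2, 3, 3};  // wrong: misreads row 1
    View2D strided{buf.data(), 2, 3, 4}; // right: skips the padding
    std::printf("packed (1,0): %g  strided (1,0): %g\n",
                packed.at(1, 0), strided.at(1, 0));
    return 0;
}
```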
include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_pipeline.hpp

```diff
@@ -8,6 +8,10 @@
 #include <string>
 #include <type_traits>

+#ifndef TOPK_SOFTMAX_USE_RAW_TILE_WINDOW
+#define TOPK_SOFTMAX_USE_RAW_TILE_WINDOW 0
+#endif
+
 namespace ck_tile {

 template <typename Problem_, typename Policy_ = TopkSoftmaxWarpPerRowPolicy>
```
The pipeline now takes the full row count and the block's starting row, switches to linear tile windows (optionally raw, behind the new macro), and moves the per-tile work into the loop shown in the next hunk:

```diff
@@ -22,42 +26,18 @@ struct TopkSoftmaxWarpPerRowPipeline
     CK_TILE_DEVICE auto operator()(const InputWindow& input_window,
                                    OutputWindow& out_window,
                                    IndexWindow& idx_window,
+                                   index_t rows,
+                                   index_t experts,
                                    index_t k,
-                                   index_t experts)
+                                   index_t block_row_id)
     {
-        auto input_win = make_tile_window(input_window.get_bottom_tensor_view(),
-                                          input_window.get_window_lengths(),
-                                          input_window.get_window_origin(),
-                                          Policy::template MakeInputDistribution<Problem>());
-        auto x = load_tile(input_win);
-
-        // cast and pad input data
-        auto w = [&]() {
-            auto w_ = cast_tile<WeightType>(x);
-
-            constexpr auto span_2d = decltype(w_)::get_distributed_spans();
-            sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
-                sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
-                    constexpr auto i_j_idx = make_tuple(idx0, idx1);
-                    const auto x_indices  = get_x_indices_from_distributed_indices(
-                        w_.get_tile_distribution(), i_j_idx);
-                    const auto current_expert = x_indices.at(number<1>{});
-                    // set to -INF if OOB so that later softmax can work properly
-                    w_(i_j_idx) = current_expert >= experts ? -numeric<WeightType>::infinity()
-                                                            : w_(i_j_idx);
-                });
-            });
-            return w_;
-        }();
-
-        auto softmax = Policy::template GetSoftmax<Problem>();
-        // softmax
-        auto y = softmax(w);
-
-        auto topk = Policy::template GetTopk<Problem>();
+#if TOPK_SOFTMAX_USE_RAW_TILE_WINDOW
+        auto inp_win = make_tile_window_linear_raw(
+            input_window, Policy::template MakeInputDistribution<Problem>(), sequence<0, 1>{});
+#else
+        auto inp_win = make_tile_window_linear(
+            input_window, Policy::template MakeInputDistribution<Problem>(), sequence<0, 1>{});
+#endif
         auto out_win = make_tile_window(out_window.get_bottom_tensor_view(),
                                         out_window.get_window_lengths(),
                                         out_window.get_window_origin(),
```
```diff
@@ -67,7 +47,61 @@ struct TopkSoftmaxWarpPerRowPipeline
                                         idx_window.get_window_origin(),
                                         Policy::template MakeOutputDistribution<Problem>());

-        topk(y, out_win, idx_win, k);
+        auto softmax = Policy::template GetSoftmax<Problem>();
+        auto topk    = Policy::template GetTopk<Problem>();
+
+        const index_t grid_rows_per_loop = gridDim.x * Problem::RowsPerBlock;
+        while(1)
+        {
+#if TOPK_SOFTMAX_USE_RAW_TILE_WINDOW
+            __builtin_amdgcn_sched_barrier(0);
+            auto x = load_tile_raw(inp_win, bool_constant<true>{}, bool_constant<true>{});
+            buffer_load_fence(number<0>{});
+            __builtin_amdgcn_sched_barrier(0);
+#else
+            auto x = load_tile(inp_win);
+#endif
+            // cast and pad input data
+            auto w = [&]() {
+                auto w_ = cast_tile<WeightType>(x);
+
+                constexpr auto span_2d = decltype(w_)::get_distributed_spans();
+                sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
+                    sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
+                        constexpr auto i_j_idx = make_tuple(idx0, idx1);
+                        const auto x_indices  = get_x_indices_from_distributed_indices(
+                            w_.get_tile_distribution(), i_j_idx);
+                        const auto current_expert = x_indices.at(number<1>{});
+                        // set to -INF if OOB so that later softmax can work properly
+                        w_(i_j_idx) = current_expert >= experts
+                                          ? -numeric<WeightType>::infinity()
+                                          : w_(i_j_idx);
+                    });
+                });
+                return w_;
+            }();
+
+            // softmax
+            auto y = softmax(w);
+
+            topk(y, out_win, idx_win, k);
+
+            // check exit
+            if constexpr(Problem::LaunchType == 0)
+            {
+                break;
+            }
+            else
+            {
+                block_row_id += grid_rows_per_loop;
+                if(block_row_id >= rows)
+                    break;
+            }
+            move_tile_window(inp_win, {grid_rows_per_loop, number<0>{}});
+            move_tile_window(out_win, {grid_rows_per_loop, number<0>{}});
+            move_tile_window(idx_win, {grid_rows_per_loop, number<0>{}});
+        }
     }
 };
 } // namespace ck_tile
```
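The loop added above is the classic persistent-kernel pattern: the grid is sized to the device rather than the problem, and each block strides forward by gridDim.x * RowsPerBlock rows per pass until the rows run out. A host-side C++ analogue of the control flow (names and numbers illustrative):

```cpp
#include <cstdio>

// Simulate one block of a persistent grid: process rows
// [block_row_id, block_row_id + rows_per_block) each pass, then stride by
// grid_rows_per_loop = grid_blocks * rows_per_block.
void persistent_block(int block_row_id, int rows, int rows_per_block, int grid_blocks)
{
    const int grid_rows_per_loop = grid_blocks * rows_per_block;
    while(true)
    {
        std::printf("block pass starting at row %d\n", block_row_id);
        // ... load tile, softmax, topk, store ...
        block_row_id += grid_rows_per_loop;
        if(block_row_id >= rows)
            break;
        // ... move the input/output tile windows down by grid_rows_per_loop ...
    }
}

int main()
{
    // Three blocks of eight rows sweeping fifty rows.
    for(int b = 0; b < 3; ++b)
        persistent_block(b * 8, 50, 8, 3);
    return 0;
}
```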
include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_policy.hpp

Both tile distributions are updated for the split RowsPerWarpPerColIssue layout:

```diff
@@ -18,7 +18,9 @@ struct TopkSoftmaxWarpPerRowPolicy
         return make_static_tile_distribution(
             tile_distribution_encoding<
                 sequence<1>,
-                tuple<sequence<Problem::IssuesPerCol, Problem::WarpsPerBlock, Problem::RowsPerWarp>,
+                tuple<sequence<Problem::IssuesPerCol,
+                               Problem::WarpsPerBlock,
+                               Problem::RowsPerWarpPerColIssue>,
                       sequence<Problem::IssuesPerRow, Problem::LanesPerRow, Problem::VectorSize>>,
                 tuple<sequence<1>, sequence<1, 2>>,
                 tuple<sequence<1>, sequence<2, 1>>,
@@ -31,12 +33,14 @@ struct TopkSoftmaxWarpPerRowPolicy
     {
         return make_static_tile_distribution(
             tile_distribution_encoding<
                 sequence<Problem::LanesPerRow>, // repeat this one
-                tuple<sequence<Problem::WarpsPerBlock, Problem::RowsPerWarp>,
+                tuple<sequence<Problem::IssuesPerCol,
+                               Problem::WarpsPerBlock,
+                               Problem::RowsPerWarpPerColIssue>,
                       sequence<1>>, // each row write out single element
                 tuple<sequence<1>, sequence<1, 0>>,
-                tuple<sequence<0>, sequence<1, 0>>,
-                sequence<2>,
-                sequence<0>>{});
+                tuple<sequence<1>, sequence<2, 0>>,
+                sequence<1, 2>,
+                sequence<0, 0>>{});
     }

     template <typename Problem>
```
include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_problem.hpp

The problem descriptor gains the LaunchType knob and derives the per-warp row count from IssuesPerCol:

```diff
@@ -13,8 +13,9 @@ template <typename InputType_,
           typename WeightType_,
           typename IndexType_,
           index_t Experts_,
-          index_t IssuesPerCol_ = 1, // issue along col, to make sure block_reduce() OK
+          index_t IssuesPerCol_ = 2, // issue along col, to make sure block_reduce() OK
           index_t BytesPerIssue_ = sizeof(InputType_),
+          index_t LaunchType_ = 0, // 0: streaming; >0: persistent, with this many blocks per CU
           index_t BlockSize_ = 256>
 struct TopkSoftmaxWarpPerRowProblem
 {
@@ -23,8 +24,10 @@ struct TopkSoftmaxWarpPerRowProblem
     using WeightType = remove_cvref_t<WeightType_>;
     using IndexType  = remove_cvref_t<IndexType_>;

+    static constexpr index_t LaunchType    = LaunchType_;
     static constexpr index_t Experts       = Experts_;
     static constexpr index_t BytesPerIssue = BytesPerIssue_;
+    static constexpr index_t IssuesPerCol  = IssuesPerCol_;
     static constexpr index_t BlockSize     = BlockSize_;
     static constexpr index_t WarpSize      = get_warp_size();
@@ -33,10 +36,9 @@ struct TopkSoftmaxWarpPerRowProblem
     static_assert(Experts % VectorSize == 0);
     static constexpr index_t LanesPerRow = min(Experts / VectorSize, WarpSize);
     static_assert(WarpSize % LanesPerRow == 0);
-    static constexpr index_t RowsPerWarp  = WarpSize / LanesPerRow;
+    static constexpr index_t RowsPerWarpPerColIssue = WarpSize / LanesPerRow;
+    static constexpr index_t RowsPerWarp            = IssuesPerCol * RowsPerWarpPerColIssue;
     static constexpr index_t IssuesPerRow = Experts / (LanesPerRow * VectorSize);
-    static constexpr index_t IssuesPerCol = IssuesPerCol_;
     static constexpr index_t WarpsPerBlock = BlockSize / WarpSize;
     static constexpr index_t RowsPerBlock  = RowsPerWarp * WarpsPerBlock;
```
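The reshuffle splits the old RowsPerWarp into RowsPerWarpPerColIssue and scales it by IssuesPerCol, whose default moves from 1 to 2, so by default a warp now covers twice as many rows per tile. A worked example with assumed numbers (wave64, 64 experts, fp16 input with 16-byte issues; BytesPerIssue and the derived VectorSize are illustrative choices, not the defaults):

```cpp
#include <algorithm>

int main()
{
    // Assumed configuration (not the library defaults).
    constexpr int Experts       = 64;
    constexpr int BytesPerIssue = 16;                // e.g. dwordx4 loads
    constexpr int VectorSize    = BytesPerIssue / 2; // 8 fp16 values per issue
    constexpr int WarpSize      = 64;
    constexpr int BlockSize     = 256;
    constexpr int IssuesPerCol  = 2;                 // the new default

    constexpr int LanesPerRow            = std::min(Experts / VectorSize, WarpSize); // 8
    constexpr int RowsPerWarpPerColIssue = WarpSize / LanesPerRow;                   // 8
    constexpr int RowsPerWarp            = IssuesPerCol * RowsPerWarpPerColIssue;    // 16
    constexpr int WarpsPerBlock          = BlockSize / WarpSize;                     // 4
    constexpr int RowsPerBlock           = RowsPerWarp * WarpsPerBlock;              // 64

    static_assert(RowsPerBlock == 64, "one block covers 64 rows per pass");
    return 0;
}
```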
test/topk_softmax/script/smoke_test.sh (new file, mode 100644)

The new smoke test sweeps odd shapes, including strided cases via the new -st_i/-st_o options:

```sh
#!/bin/sh
EXE=./build/bin/test_topk_softmax

for pr_i in "fp16" "bf16" ; do
$EXE -pr_i=$pr_i -t=80 -e=17
$EXE -pr_i=$pr_i -t=111 -e=117
$EXE -pr_i=$pr_i -t=1000 -e=55
$EXE -pr_i=$pr_i -t=99 -e=180
$EXE -pr_i=$pr_i -t=175 -e=64 -k=8
$EXE -pr_i=$pr_i -t=65 -e=8 -k=2
$EXE -pr_i=$pr_i -t=1 -e=25
$EXE -pr_i=$pr_i -t=31 -e=19 -k=15
$EXE -pr_i=$pr_i -t=81 -e=37 -k=7
$EXE -pr_i=$pr_i -t=199 -e=128 -k=13
$EXE -pr_i=$pr_i -t=23 -e=1 -k=1
$EXE -pr_i=$pr_i -t=127 -e=99 -k=19 -st_i=233 -st_o=31
$EXE -pr_i=$pr_i -t=71 -e=11 -k=11 -st_i=30 -st_o=12
$EXE -pr_i=$pr_i -t=1 -e=1 -k=1
$EXE -pr_i=$pr_i -t=99 -e=2 -k=1 -st_i=11 -st_o=5
$EXE -pr_i=$pr_i -t=333 -e=99 -k=13 -st_i=191 -st_o=17
done
```
test/topk_softmax/topk_softmax.cpp

```diff
@@ -18,6 +18,11 @@
 #define TEST_TOPK_SOFTMAX_VERBOSE 1
 #endif

+// set this to 1 to verify row by row (useful when input/output have stride)
+#ifndef TEST_TOPK_VERIFY_PER_TOKEN
+#define TEST_TOPK_VERIFY_PER_TOKEN 1
+#endif
+
 template <typename T>
 void dump_host_tensor_2d(const ck_tile::HostTensor<T>& x)
 {
```
A second reference_topk_softmax overload is added that writes into preallocated (and therefore possibly strided) output tensors:

```diff
@@ -62,19 +67,32 @@ auto reference_topk_softmax(const ck_tile::HostTensor<InputType>& x,
 {
     using namespace ck_tile;
+    // dump_host_tensor_2d(x);
     auto y = reference_softmax<InputType, WeightType, WeightType>(x, dim);
+    // dump_host_tensor_2d(y);
     auto [y_values, y_indices] = reference_topk(y, k, dim, largest, sorted);
+    // dump_host_tensor_2d(y_values);
+    // dump_host_tensor_2d(y_indices);
     return ck_tile::make_tuple(y_values, y_indices);
 }
+
+template <typename InputType, typename WeightType, typename IndexType = ck_tile::index_t>
+auto reference_topk_softmax(const ck_tile::HostTensor<InputType>& x,
+                            ck_tile::HostTensor<WeightType>& y_values,
+                            ck_tile::HostTensor<IndexType>& y_indices,
+                            ck_tile::index_t k,
+                            ck_tile::index_t dim = -1,
+                            bool largest = true,
+                            bool sorted  = true)
+{
+    using namespace ck_tile;
+    // dump_host_tensor_2d(x);
+    auto y = reference_softmax<InputType, WeightType, WeightType>(x, dim);
+    // dump_host_tensor_2d(y);
+    reference_topk(y, y_values, y_indices, k, dim, largest, sorted);
+}
```
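For intuition about what the test verifies, here is a standalone strided reference of the same computation in plain C++, independent of ck_tile's HostTensor: row r of the input starts at r * stride_in and row r of the outputs at r * stride_out, mirroring the new -st_i/-st_o options:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <numeric>
#include <vector>

// Numerically stable softmax over `experts` values of one row, then top-k.
void topk_softmax_row(const float* in, int experts, int k, float* out_vals, int* out_idx)
{
    const float m = *std::max_element(in, in + experts);
    std::vector<float> p(experts);
    float sum = 0.f;
    for(int e = 0; e < experts; ++e)
        sum += (p[e] = std::exp(in[e] - m));
    for(auto& v : p)
        v /= sum;

    // indices of the k largest probabilities, descending
    std::vector<int> order(experts);
    std::iota(order.begin(), order.end(), 0);
    std::partial_sort(order.begin(), order.begin() + k, order.end(),
                      [&](int a, int b) { return p[a] > p[b]; });
    for(int i = 0; i < k; ++i)
    {
        out_idx[i]  = order[i];
        out_vals[i] = p[order[i]];
    }
}

int main()
{
    const int rows = 2, experts = 4, k = 2;
    const int stride_in = 5, stride_out = 3; // padded rows, as in the new test
    std::vector<float> x(rows * stride_in, 0.f);
    std::vector<float> vals(rows * stride_out, 0.f);
    std::vector<int> idx(rows * stride_out, 0);
    x[0] = 1.f; x[1] = 2.f; x[2] = 3.f; x[3] = 0.f; // row 0
    x[stride_in + 0] = 4.f; x[stride_in + 1] = 1.f; // row 1
    for(int r = 0; r < rows; ++r)
        topk_softmax_row(&x[r * stride_in], experts, k,
                         &vals[r * stride_out], &idx[r * stride_out]);
    std::printf("row 0 top-1: expert %d, p=%.3f\n", idx[0], vals[0]);
    return 0;
}
```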
The argument parser renames the precision options to the short forms used by the new smoke test and gains the stride options:

```diff
 // different threshold for different dtype
 template <typename DataType>
 auto get_elimit(std::string /*init_method*/)
 ...
@@ -113,12 +131,13 @@ auto create_args(int argc, char* argv[])
 {
     ck_tile::ArgParser arg_parser;
     arg_parser.insert("v", "1", "whether to do CPU validation or not")
-        .insert("input_prec", "fp16", "input data type. fp8/fp16/fp32 (representing 8/16/32 bit data)")
-        .insert("weight_prec", "fp32", "weight data type")
+        .insert("pr_i", "fp16", "input data type. fp16/fp32 (representing 8/16/32 bit data)")
+        .insert("pr_w", "fp32", "weight data type (currently only fp32 supported)")
         .insert("t", "32", "number of input tokens")
         .insert("e", "8", "number of experts")
         .insert("k", "2", "topk")
+        .insert("st_i", "-1", "row stride of input, -1 means same as experts")
+        .insert("st_o", "-1", "row stride of output/indices, -1 means same as topk")
         .insert("seed", "-1", "seed to be used, -1 means random every time")
         .insert("kname", "0", "set to 1 will print kernel name");
```
@@ -130,12 +149,25 @@ template <typename InputType, typename WeightType, typename IndexType = ck_tile:
...
@@ -130,12 +149,25 @@ template <typename InputType, typename WeightType, typename IndexType = ck_tile:
bool
test_topk_softmax
(
ck_tile
::
ArgParser
args
)
bool
test_topk_softmax
(
ck_tile
::
ArgParser
args
)
{
{
int
validate
=
args
.
get_int
(
"v"
);
int
validate
=
args
.
get_int
(
"v"
);
std
::
string
input_prec
=
args
.
get_str
(
"
input_prec
"
);
std
::
string
input_prec
=
args
.
get_str
(
"
pr_i
"
);
std
::
string
weight_prec
=
args
.
get_str
(
"
weight_prec
"
);
std
::
string
weight_prec
=
args
.
get_str
(
"
pr_w
"
);
int
tokens
=
args
.
get_int
(
"t"
);
int
tokens
=
args
.
get_int
(
"t"
);
int
experts
=
args
.
get_int
(
"e"
);
int
experts
=
args
.
get_int
(
"e"
);
int
topk
=
args
.
get_int
(
"k"
);
int
topk
=
args
.
get_int
(
"k"
);
int
seed
=
args
.
get_int
(
"seed"
);
int
seed
=
args
.
get_int
(
"seed"
);
int
stride_input
=
args
.
get_int
(
"st_i"
);
int
stride_output
=
args
.
get_int
(
"st_o"
);
if
(
stride_input
<
0
)
{
stride_input
=
experts
;
}
if
(
stride_output
<
0
)
{
stride_output
=
topk
;
}
assert
(
stride_input
>=
experts
);
assert
(
stride_output
>=
topk
);
if
(
seed
<
0
)
if
(
seed
<
0
)
{
{
seed
=
std
::
time
(
nullptr
);
seed
=
std
::
time
(
nullptr
);
...
```diff
@@ -153,9 +185,9 @@ bool test_topk_softmax(ck_tile::ArgParser args)
     }

     // tokens already considered batch size
-    ck_tile::HostTensor<InputType> x_host({tokens, experts});
-    ck_tile::HostTensor<WeightType> value_host({tokens, topk});
-    ck_tile::HostTensor<IndexType> index_host({tokens, topk});
+    ck_tile::HostTensor<InputType> x_host({tokens, experts}, {stride_input, 1});
+    ck_tile::HostTensor<WeightType> value_host({tokens, topk}, {stride_output, 1});
+    ck_tile::HostTensor<IndexType> index_host({tokens, topk}, {stride_output, 1});
     {
         // random require per-row unique
@@ -166,7 +198,7 @@ bool test_topk_softmax(ck_tile::ArgParser args)
         {
             ck_tile::HostTensor<InputType> x_row({experts});
             rand_gen(x_row);
-            std::copy(x_row.begin(), x_row.end(), x_host.begin() + i_t * experts);
+            std::copy(x_row.begin(), x_row.end(), x_host.begin() + i_t * stride_input);
             rand_gen.clear();
         }
     }
```
```diff
@@ -187,30 +219,41 @@ bool test_topk_softmax(ck_tile::ArgParser args)
     topk_softmax_kargs karg = [&]() {
         topk_softmax_kargs a_;
         a_.p_input     = x_dev.GetDeviceBuffer();
         a_.p_output    = value_dev.GetDeviceBuffer();
         a_.p_indices   = index_dev.GetDeviceBuffer();
         a_.num_rows    = tokens;
         a_.num_experts = experts;
         a_.topk        = topk;
+        a_.stride_input  = stride_input;
+        a_.stride_output = stride_output;
         return a_;
     }();

 #if TEST_TOPK_SOFTMAX_VERBOSE
     ck_tile::stream_config sc{nullptr, true};
+    // ck_tile::stream_config sc{nullptr};
     auto ms = topk_softmax(trait, karg, sc);
-    printf("[%s|%s]tokens:%d, experts:%d, topk:%d, ms:%f, ",
+    printf("[%s|%s]tokens:%d, experts:%d, topk:%d, st_i:%d, st_o:%d, ms:%f, ",
            input_prec.c_str(),
            weight_prec.c_str(),
            tokens,
            experts,
            topk,
+           stride_input,
+           stride_output,
            ms);
+    if(ms < 0)
+        printf("not supported\n");
     fflush(stdout);
 #else
     ck_tile::stream_config sc{nullptr};
-    topk_softmax(trait, karg, sc);
+    auto ms = topk_softmax(trait, karg, sc);
 #endif
+    if(ms < 0)
+    {
+        return false;
+    }

     value_dev.FromDevice(value_host.data());
     index_dev.FromDevice(index_host.data());
```
Verification now builds strided reference tensors and, by default, checks the results token by token so a failing row is easy to spot:

```diff
@@ -218,17 +261,44 @@ bool test_topk_softmax(ck_tile::ArgParser args)
     bool rtn = true;
     if(validate)
     {
-        ck_tile::HostTensor<WeightType> value_host_ref({tokens, topk});
-        ck_tile::HostTensor<IndexType> index_host_ref({tokens, topk});
-        auto [value_ref, index_ref] =
-            reference_topk_softmax<InputType, WeightType, IndexType>(x_host, topk);
+        // this host buffer will not copy to GPU, so no need use stride
+        ck_tile::HostTensor<WeightType> value_ref({tokens, topk}, {stride_output, 1});
+        ck_tile::HostTensor<IndexType> index_ref({tokens, topk}, {stride_output, 1});
+        // auto [value_ref, index_ref] =
+        reference_topk_softmax<InputType, WeightType, IndexType>(x_host, value_ref, index_ref, topk);

         auto [rtol, atol] = get_elimit<InputType>("");
+#if TEST_TOPK_VERIFY_PER_TOKEN
+        for(int i_t = 0; i_t < tokens; i_t++)
+        {
+            auto s_begin = std::vector<size_t>{static_cast<size_t>(i_t), static_cast<size_t>(0)};
+            auto s_end =
+                std::vector<size_t>{static_cast<size_t>(i_t + 1), static_cast<size_t>(topk)};
+
+            auto s_value_host = value_host.slice(s_begin, s_end);
+            auto s_value_ref  = value_ref.slice(s_begin, s_end);
+            rtn &= ck_tile::check_err(s_value_host,
+                                      s_value_ref,
+                                      std::string("[") + std::to_string(i_t) +
+                                          std::string("] Value Error:"),
+                                      rtol,
+                                      atol);
+
+            auto s_index_host = index_host.slice(s_begin, s_end);
+            auto s_index_ref  = index_ref.slice(s_begin, s_end);
+            rtn &= ck_tile::check_err(s_index_host,
+                                      s_index_ref,
+                                      std::string("[") + std::to_string(i_t) +
+                                          std::string("] Index Error:"),
+                                      rtol,
+                                      atol);
+        }
+#else
         rtn &= ck_tile::check_err(
             value_host, value_ref, std::string("Value Error: Incorrect results!"), rtol, atol);
         rtn &= ck_tile::check_err(
             index_host, index_ref, std::string("Index Error: Incorrect results!"), rtol, atol);
+#endif
     }

 #if TEST_TOPK_SOFTMAX_VERBOSE
     printf("valid:%s\n", rtn ? "y" : "n");
```
```diff
@@ -242,8 +312,8 @@ int main(int argc, char** argv)
     auto [result, args] = create_args(argc, argv);
     if(!result)
         return -1;

-    std::string input_prec  = args.get_str("input_prec");
-    std::string weight_prec = args.get_str("weight_prec");
+    std::string input_prec  = args.get_str("pr_i");
+    std::string weight_prec = args.get_str("pr_w");

     bool r = true;
     if(input_prec.compare("fp16") == 0 && weight_prec.compare("fp32") == 0)
```