Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
a4140bb9
"include/exceptions.h" did not exist on "81ff4946ae383e21c54ba306bf89f85d9ae371b2"
Commit
a4140bb9
authored
Nov 02, 2024
by
carlushuang
Browse files
use magic div to accelerate compute
parent
a61ccfe8
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
102 additions
and
45 deletions
+102
-45
example/ck_tile/13_moe_sorting/moe_sorting.cpp
example/ck_tile/13_moe_sorting/moe_sorting.cpp
+10
-10
example/ck_tile/13_moe_sorting/moe_sorting_api.cpp
example/ck_tile/13_moe_sorting/moe_sorting_api.cpp
+7
-9
example/ck_tile/13_moe_sorting/moe_sorting_api.hpp
example/ck_tile/13_moe_sorting/moe_sorting_api.hpp
+2
-2
include/ck_tile/ops/moe_sorting.hpp
include/ck_tile/ops/moe_sorting.hpp
+1
-0
include/ck_tile/ops/moe_sorting/kernel/moe_sorting_kernel.hpp
...ude/ck_tile/ops/moe_sorting/kernel/moe_sorting_kernel.hpp
+82
-24
No files found.
example/ck_tile/13_moe_sorting/moe_sorting.cpp
View file @
a4140bb9
...
...
@@ -110,7 +110,7 @@ bool test_moe_sorting(ck_tile::ArgParser args)
moe_sorting_trait
trait
{
index_prec
,
weight_prec
,
experts
,
topk
,
unit_size
,
tokens
};
moe_sorting_
k
args
karg
{
topk_ids_dev
.
GetDeviceBuffer
(),
moe_sorting_args
karg
{
topk_ids_dev
.
GetDeviceBuffer
(),
weights_dev
.
GetDeviceBuffer
(),
sorted_ids_dev
.
GetDeviceBuffer
(),
sorted_weights_dev
.
GetDeviceBuffer
(),
...
...
example/ck_tile/13_moe_sorting/moe_sorting_api.cpp
View file @
a4140bb9
...
...
@@ -3,7 +3,7 @@
#include "moe_sorting_api.hpp"
float
moe_sorting
(
moe_sorting_trait
t
,
moe_sorting_
k
args
a
,
ck_tile
::
stream_config
s
)
float
moe_sorting
(
moe_sorting_trait
t
,
moe_sorting_args
a
,
ck_tile
::
stream_config
s
)
{
if
(
t
.
weight_type
==
"fp32"
&&
t
.
index_type
==
"int32"
)
{
...
...
@@ -15,14 +15,12 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_kargs a, ck_tile::stream_conf
using
index_t
=
ck_tile
::
index_t
;
using
ms_weight_type
=
float
;
using
ms_problem
=
ck_tile
::
MoeSortingProblem
<
index_t
,
ms_weight_type
>
;
// using ms_pipeline = ck_tile::MoeSortingPipeline<ms_problem>;
using
kernel
=
ck_tile
::
MoeSortingKernel
<
ms_problem
>
;
auto
kargs
=
kernel
::
MakeKargs
(
a
);
const
dim3
grids
=
1
;
const
dim3
blocks
=
ck_tile
::
max
(
t
.
experts
,
ck_tile
::
get_warp_size
());
const
size_t
lds_size
=
((
blocks
.
x
+
1
)
*
t
.
experts
+
(
t
.
experts
+
1
))
*
sizeof
(
index_t
);
float
ave_time
=
ck_tile
::
launch_kernel
(
s
,
ck_tile
::
make_kernel
(
kernel
{},
grids
,
blocks
,
lds_size
,
kargs
));
const
dim3
grids
=
kernel
::
GridSize
(
a
);
const
dim3
blocks
=
kernel
::
BlockSize
(
a
);
float
ave_time
=
ck_tile
::
launch_kernel
(
s
,
ck_tile
::
make_kernel
(
kernel
{},
grids
,
blocks
,
0
,
kargs
));
return
ave_time
;
}
return
-
1
;
...
...
example/ck_tile/13_moe_sorting/moe_sorting_api.hpp
View file @
a4140bb9
...
...
@@ -17,8 +17,8 @@ struct moe_sorting_trait
int
tokens
;
};
struct
moe_sorting_
k
args
:
public
ck_tile
::
MoeSortingHostArgs
// Argument bundle accepted by the example-level moe_sorting() entry point.
// It adds no members of its own: everything comes from
// ck_tile::MoeSortingHostArgs.
struct moe_sorting_args : public ck_tile::MoeSortingHostArgs
{
};
float
moe_sorting
(
moe_sorting_trait
t
,
moe_sorting_
k
args
a
,
ck_tile
::
stream_config
s
);
float
moe_sorting
(
moe_sorting_trait
t
,
moe_sorting_args
a
,
ck_tile
::
stream_config
s
);
include/ck_tile/ops/moe_sorting.hpp
View file @
a4140bb9
...
...
@@ -7,4 +7,5 @@
#include "ck_tile/ops/moe_sorting/pipeline/moe_sorting_pipeline.hpp"
#include "ck_tile/ops/moe_sorting/pipeline/moe_sorting_policy.hpp"
#include "ck_tile/ops/moe_sorting/pipeline/moe_sorting_problem.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/moe_sorting/kernel/moe_sorting_kernel.hpp
View file @
a4140bb9
...
...
@@ -29,7 +29,6 @@ struct MoeSortingHostArgs
template
<
typename
Problem_
>
struct
MoeSortingKernel
{
// using Pipeline = remove_cvref_t<Pipeline_>;
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
IndexType
=
typename
Problem
::
IndexType
;
...
...
@@ -37,10 +36,63 @@ struct MoeSortingKernel
typedef
MoeSortingHostArgs
MoeSortingKargs
;
using
Kargs
=
MoeSortingKargs
;
using
Hargs
=
MoeSortingHostArgs
;
CK_TILE_HOST
static
constexpr
auto
MakeKargs
(
const
Hargs
&
h
)
{
return
h
;
}
// Kernel-side argument pack; MakeKargs() fills it in from the host arguments.
// Member order is kept as-is because it defines the layout handed to the kernel.
struct Kargs
{
    // Read-only buffers, type-erased here; operator() casts them to
    // IndexType / WeightType before use.
    const void* p_topk_ids;
    const void* p_weights;

    // Buffers written by the kernel (also type-erased).
    void* sorted_token_ids;
    void* sorted_weights;
    void* expert_ids;
    void* total_tokens_post_pad;

    index_t tokens;
    index_t num_experts;
    // ceil(tokens * topk / block size), precomputed on the host in MakeKargs().
    index_t tokens_per_thread;

    // Divider helpers built from the host-side unit_size and topk; the kernel
    // uses their .div() / .divisor members in place of plain '/'.
    mdiv unit_size_mdiv;
    mdiv topk_mdiv;
};
// Grid dimension used to launch the kernel. The host arguments are not
// consulted: a single block is always requested.
// TODO: this relies on the number of experts not being too large.
CK_TILE_HOST static constexpr auto GridSize(const Hargs& /*unused*/)
{
    constexpr int num_blocks = 1;
    return dim3(num_blocks);
}
// Block dimension used to launch the kernel: num_experts threads, but never
// fewer than get_warp_size().
// TODO: the result still needs to be padded to a multiple of the warp size.
CK_TILE_HOST static constexpr auto BlockSize(const Hargs& h)
{
    const auto threads = ck_tile::max(h.num_experts, ck_tile::get_warp_size());
    return dim3(threads);
}
// Size, in bytes, of the shared-memory buffer declared in operator().
//
// The size that depends on the launch configuration would be
//   ((blockDim.x + 1) * num_experts + (num_experts + 1)) * sizeof(index_t)
// TODO: a dynamic calculation cannot be used here, the compiler needs a static
// value, so a fixed 64 KiB is reported instead.
CK_TILE_DEVICE static constexpr index_t GetSmemSize()
{
    constexpr index_t smem_bytes = 64 * 1024;
    return smem_bytes;
}
// Build the kernel argument pack from the host-side arguments.
//
// Buffer pointers and the token / expert counts are forwarded unchanged.
// Two things are derived here on the host:
//   - tokens_per_thread: tokens * topk entries divided (rounding up) by the
//     x-dimension of BlockSize(h);
//   - unit_size_mdiv / topk_mdiv: mdiv objects constructed from unit_size and
//     topk (presumably precomputed "magic" divisors, as the commit title
//     suggests -- confirm against the mdiv definition).
CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
{
    Kargs kargs;

    // Pass-through buffers.
    kargs.p_topk_ids            = h.p_topk_ids;
    kargs.p_weights             = h.p_weights;
    kargs.sorted_token_ids      = h.sorted_token_ids;
    kargs.sorted_weights        = h.sorted_weights;
    kargs.expert_ids            = h.expert_ids;
    kargs.total_tokens_post_pad = h.total_tokens_post_pad;

    // Pass-through sizes.
    kargs.tokens      = h.tokens;
    kargs.num_experts = h.num_experts;

    // Work per thread depends on the block size the kernel will be launched with.
    const auto block_dim    = BlockSize(h);
    kargs.tokens_per_thread = integer_divide_ceil(h.tokens * h.topk, block_dim.x);

    // Divider helpers consumed by the kernel via .div() / .divisor.
    kargs.unit_size_mdiv = mdiv{static_cast<uint32_t>(h.unit_size)};
    kargs.topk_mdiv      = mdiv{static_cast<uint32_t>(h.topk)};

    return kargs;
}
CK_TILE_DEVICE
index_t
calc_index
(
index_t
total_col
,
index_t
row
,
index_t
col
)
const
{
...
...
@@ -54,15 +106,16 @@ struct MoeSortingKernel
index_t
*
expert_ids
,
index_t
*
total_tokens_post_pad
,
const
index_t
num_experts
,
const
index_t
unit_size
,
const
index_t
tokens_per_thread
,
const
index_t
numel
,
const
index_t
topk
)
const
const
mdiv
unit_size_mdiv
,
const
mdiv
topk_mdiv
,
void
*
smem
)
const
{
const
index_t
tokens_per_thread
=
integer_divide_ceil
(
numel
,
blockDim
.
x
);
const
index_t
tid
=
static_cast
<
index_t
>
(
threadIdx
.
x
);
const
index_t
start_idx
=
tid
*
tokens_per_thread
;
extern
__shared__
index_t
shared_
mem
[]
;
index_t
*
shared_mem
=
reinterpret_cast
<
index_t
*>
(
s
mem
)
;
index_t
*
tokens_cnts
=
shared_mem
;
// 2d: (blockDim.x + 1, num_experts)
index_t
*
cumsum
=
shared_mem
+
(
blockDim
.
x
+
1
)
*
num_experts
;
// 1: (num_experts + 1)
...
...
@@ -88,28 +141,29 @@ struct MoeSortingKernel
}
}
__syncthreads
();
//
__syncthreads();
if
(
tid
==
0
)
{
cumsum
[
0
]
=
0
;
for
(
int
i
=
1
;
i
<=
num_experts
;
++
i
)
{
cumsum
[
i
]
=
cumsum
[
i
-
1
]
+
max
(
integer_divide_ceil
(
tokens_cnts
[
calc_index
(
num_experts
,
blockDim
.
x
,
i
-
1
)],
unit_size
),
1
)
*
unit_size
;
auto
current_units
=
[
&
]()
{
index_t
x_
=
tokens_cnts
[
calc_index
(
num_experts
,
blockDim
.
x
,
i
-
1
)]
+
unit_size_mdiv
.
divisor
-
1
;
index_t
y_
=
unit_size_mdiv
.
div
(
x_
);
return
max
(
y_
,
1
)
*
unit_size_mdiv
.
divisor
;
}();
cumsum
[
i
]
=
cumsum
[
i
-
1
]
+
current_units
;
}
*
total_tokens_post_pad
=
cumsum
[
num_experts
]
/
unit_size
;
*
total_tokens_post_pad
=
unit_size_mdiv
.
div
(
cumsum
[
num_experts
]
)
;
}
__syncthreads
();
if
(
tid
<
num_experts
)
{
for
(
int
i
=
cumsum
[
tid
];
i
<
cumsum
[
tid
+
1
];
i
+=
unit_size
)
for
(
int
i
=
cumsum
[
tid
];
i
<
cumsum
[
tid
+
1
];
i
+=
unit_size
_mdiv
.
divisor
)
{
expert_ids
[
i
/
unit_size
]
=
tid
;
expert_ids
[
unit_size
_mdiv
.
div
(
i
)
]
=
tid
;
}
}
...
...
@@ -118,11 +172,12 @@ struct MoeSortingKernel
index_t
expert_id
=
topk_id
[
i
];
index_t
rank_post_pad
=
tokens_cnts
[
calc_index
(
num_experts
,
tid
,
expert_id
)]
+
cumsum
[
expert_id
];
sorted_token_ids
[
rank_post_pad
]
=
i
/
topk
;
sorted_token_ids
[
rank_post_pad
]
=
topk
_mdiv
.
div
(
i
)
;
sorted_weights
[
rank_post_pad
]
=
weights
[
i
];
++
tokens_cnts
[
calc_index
(
num_experts
,
tid
,
expert_id
)];
}
const
index_t
prefill_token
=
numel
/
topk
;
const
index_t
prefill_token
=
topk_mdiv
.
div
(
numel
);
if
(
tid
<
num_experts
)
{
index_t
expert_offset
=
...
...
@@ -138,7 +193,8 @@ struct MoeSortingKernel
CK_TILE_DEVICE
void
operator
()(
Kargs
kargs
)
const
{
const
size_t
numel
=
kargs
.
tokens
*
kargs
.
topk
;
const
size_t
numel
=
kargs
.
tokens
*
kargs
.
topk_mdiv
.
divisor
;
__shared__
char
smem
[
GetSmemSize
()];
return
moe_align_block_size_kernel
(
static_cast
<
const
IndexType
*>
(
kargs
.
p_topk_ids
),
static_cast
<
const
WeightType
*>
(
kargs
.
p_weights
),
static_cast
<
IndexType
*>
(
kargs
.
sorted_token_ids
),
...
...
@@ -146,9 +202,11 @@ struct MoeSortingKernel
static_cast
<
IndexType
*>
(
kargs
.
expert_ids
),
static_cast
<
IndexType
*>
(
kargs
.
total_tokens_post_pad
),
kargs
.
num_experts
,
kargs
.
unit_size
,
kargs
.
tokens_per_thread
,
numel
,
kargs
.
topk
);
kargs
.
unit_size_mdiv
,
kargs
.
topk_mdiv
,
smem
);
}
};
}
// namespace ck_tile
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment