Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
bfe0120a
Commit
bfe0120a
authored
Nov 05, 2024
by
dummycoderfe
Browse files
add extblocksnel to set zeros for moebufs
parent
8f4dc357
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
49 additions
and
17 deletions
+49
-17
example/ck_tile/13_moe_sorting/moe_sorting.cpp
example/ck_tile/13_moe_sorting/moe_sorting.cpp
+15
-1
example/ck_tile/13_moe_sorting/moe_sorting_api.cpp
example/ck_tile/13_moe_sorting/moe_sorting_api.cpp
+8
-7
include/ck_tile/ops/moe_sorting/kernel/moe_sorting_kernel.hpp
...ude/ck_tile/ops/moe_sorting/kernel/moe_sorting_kernel.hpp
+26
-9
No files found.
example/ck_tile/13_moe_sorting/moe_sorting.cpp
View file @
bfe0120a
...
@@ -25,6 +25,7 @@ auto create_args(int argc, char* argv[])
...
@@ -25,6 +25,7 @@ auto create_args(int argc, char* argv[])
.
insert
(
"e"
,
"8"
,
"number of experts"
)
.
insert
(
"e"
,
"8"
,
"number of experts"
)
.
insert
(
"k"
,
"4"
,
"topk"
)
.
insert
(
"k"
,
"4"
,
"topk"
)
.
insert
(
"unit"
,
"32"
,
"unit_size"
)
.
insert
(
"unit"
,
"32"
,
"unit_size"
)
.
insert
(
"moe_buf_size"
,
"-1"
,
"moe_buf_size"
)
.
insert
(
"seed"
,
"-1"
,
"seed to be used, -1 means random every time"
)
.
insert
(
"seed"
,
"-1"
,
"seed to be used, -1 means random every time"
)
.
insert
(
"kname"
,
"0"
,
"when set to 1 it will print kernel name"
)
.
insert
(
"kname"
,
"0"
,
"when set to 1 it will print kernel name"
)
.
insert
(
"warmup"
,
"5"
,
"number of iterations before benchmark the kernel"
)
.
insert
(
"warmup"
,
"5"
,
"number of iterations before benchmark the kernel"
)
...
@@ -69,6 +70,7 @@ bool test_moe_sorting(ck_tile::ArgParser args)
...
@@ -69,6 +70,7 @@ bool test_moe_sorting(ck_tile::ArgParser args)
int
topk
=
args
.
get_int
(
"k"
);
int
topk
=
args
.
get_int
(
"k"
);
int
seed
=
args
.
get_int
(
"seed"
);
int
seed
=
args
.
get_int
(
"seed"
);
int
unit_size
=
args
.
get_int
(
"unit"
);
int
unit_size
=
args
.
get_int
(
"unit"
);
int
moe_buf_size
=
args
.
get_int
(
"moe_buf_size"
);
int
kname
=
args
.
get_int
(
"kname"
);
int
kname
=
args
.
get_int
(
"kname"
);
int
warmup
=
args
.
get_int
(
"warmup"
);
int
warmup
=
args
.
get_int
(
"warmup"
);
int
repeat
=
args
.
get_int
(
"repeat"
);
int
repeat
=
args
.
get_int
(
"repeat"
);
...
@@ -94,8 +96,10 @@ bool test_moe_sorting(ck_tile::ArgParser args)
...
@@ -94,8 +96,10 @@ bool test_moe_sorting(ck_tile::ArgParser args)
ck_tile
::
HostTensor
<
WeightType
>
sorted_weights_host
({
max_output_ids
},
{
1
});
ck_tile
::
HostTensor
<
WeightType
>
sorted_weights_host
({
max_output_ids
},
{
1
});
ck_tile
::
HostTensor
<
IndexType
>
expert_ids_host
({
max_output_ids
/
unit_size
},
{
1
});
ck_tile
::
HostTensor
<
IndexType
>
expert_ids_host
({
max_output_ids
/
unit_size
},
{
1
});
ck_tile
::
HostTensor
<
IndexType
>
sorted_id_cnt_host
({
1
},
{
1
});
ck_tile
::
HostTensor
<
IndexType
>
sorted_id_cnt_host
({
1
},
{
1
});
ck_tile
::
HostTensor
<
IndexType
>
moe_buf_host
({
moe_buf_size
},
{
1
});
ck_tile
::
FillUniformDistribution
<
WeightType
>
{
-
.5
f
,
.5
f
}(
weights_host
);
ck_tile
::
FillUniformDistribution
<
WeightType
>
{
-
.5
f
,
.5
f
}(
weights_host
);
ck_tile
::
FillUniformDistribution
<
WeightType
>
{
-
.5
f
,
.5
f
}(
moe_buf_host
);
topid_unique_gen
<
IndexType
>
(
topk_ids_host
.
mData
,
tokens
,
topk
,
experts
,
seed
);
topid_unique_gen
<
IndexType
>
(
topk_ids_host
.
mData
,
tokens
,
topk
,
experts
,
seed
);
ck_tile
::
DeviceMem
topk_ids_dev
(
topk_ids_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
topk_ids_dev
(
topk_ids_host
.
get_element_space_size_in_bytes
());
...
@@ -104,9 +108,11 @@ bool test_moe_sorting(ck_tile::ArgParser args)
...
@@ -104,9 +108,11 @@ bool test_moe_sorting(ck_tile::ArgParser args)
ck_tile
::
DeviceMem
sorted_weights_dev
(
sorted_weights_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
sorted_weights_dev
(
sorted_weights_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
expert_ids_dev
(
expert_ids_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
expert_ids_dev
(
expert_ids_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
sorted_id_cnt_dev
(
sorted_id_cnt_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
sorted_id_cnt_dev
(
sorted_id_cnt_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
moe_buf_dev
(
moe_buf_host
.
get_element_space_size_in_bytes
());
topk_ids_dev
.
ToDevice
(
topk_ids_host
.
data
());
topk_ids_dev
.
ToDevice
(
topk_ids_host
.
data
());
weights_dev
.
ToDevice
(
weights_host
.
data
());
weights_dev
.
ToDevice
(
weights_host
.
data
());
moe_buf_dev
.
ToDevice
(
moe_buf_host
.
data
());
moe_sorting_trait
trait
{
index_prec
,
weight_prec
,
experts
,
topk
,
unit_size
,
tokens
};
moe_sorting_trait
trait
{
index_prec
,
weight_prec
,
experts
,
topk
,
unit_size
,
tokens
};
...
@@ -116,10 +122,12 @@ bool test_moe_sorting(ck_tile::ArgParser args)
...
@@ -116,10 +122,12 @@ bool test_moe_sorting(ck_tile::ArgParser args)
sorted_weights_dev
.
GetDeviceBuffer
(),
sorted_weights_dev
.
GetDeviceBuffer
(),
expert_ids_dev
.
GetDeviceBuffer
(),
expert_ids_dev
.
GetDeviceBuffer
(),
sorted_id_cnt_dev
.
GetDeviceBuffer
(),
sorted_id_cnt_dev
.
GetDeviceBuffer
(),
moe_buf_dev
.
GetDeviceBuffer
(),
tokens
,
tokens
,
unit_size
,
unit_size
,
experts
,
experts
,
topk
};
topk
,
moe_buf_size
};
ck_tile
::
stream_config
sc
{
nullptr
,
ck_tile
::
stream_config
sc
{
nullptr
,
true
,
true
,
...
@@ -146,6 +154,7 @@ bool test_moe_sorting(ck_tile::ArgParser args)
...
@@ -146,6 +154,7 @@ bool test_moe_sorting(ck_tile::ArgParser args)
sorted_weights_dev
.
FromDevice
(
sorted_weights_host
.
data
());
sorted_weights_dev
.
FromDevice
(
sorted_weights_host
.
data
());
expert_ids_dev
.
FromDevice
(
expert_ids_host
.
data
());
expert_ids_dev
.
FromDevice
(
expert_ids_host
.
data
());
sorted_id_cnt_dev
.
FromDevice
(
sorted_id_cnt_host
.
data
());
sorted_id_cnt_dev
.
FromDevice
(
sorted_id_cnt_host
.
data
());
moe_buf_dev
.
FromDevice
(
moe_buf_host
.
data
());
bool
rtn
=
true
;
bool
rtn
=
true
;
if
(
validate
)
if
(
validate
)
...
@@ -153,6 +162,9 @@ bool test_moe_sorting(ck_tile::ArgParser args)
...
@@ -153,6 +162,9 @@ bool test_moe_sorting(ck_tile::ArgParser args)
ck_tile
::
HostTensor
<
IndexType
>
sorted_ids_ref
({
max_output_ids
},
{
1
});
ck_tile
::
HostTensor
<
IndexType
>
sorted_ids_ref
({
max_output_ids
},
{
1
});
ck_tile
::
HostTensor
<
WeightType
>
sorted_weights_ref
({
max_output_ids
},
{
1
});
ck_tile
::
HostTensor
<
WeightType
>
sorted_weights_ref
({
max_output_ids
},
{
1
});
ck_tile
::
HostTensor
<
IndexType
>
expert_ids_ref
({
max_output_ids
/
unit_size
},
{
1
});
ck_tile
::
HostTensor
<
IndexType
>
expert_ids_ref
({
max_output_ids
/
unit_size
},
{
1
});
ck_tile
::
HostTensor
<
IndexType
>
moe_buf_ref
({
moe_buf_size
},
{
1
});
moe_buf_ref
.
SetZero
();
int32_t
total_tokens_post_pad
=
0
;
int32_t
total_tokens_post_pad
=
0
;
ck_tile
::
reference_moe_sorting
<
WeightType
,
IndexType
>
(
topk_ids_host
,
ck_tile
::
reference_moe_sorting
<
WeightType
,
IndexType
>
(
topk_ids_host
,
weights_host
,
weights_host
,
...
@@ -171,6 +183,8 @@ bool test_moe_sorting(ck_tile::ArgParser args)
...
@@ -171,6 +183,8 @@ bool test_moe_sorting(ck_tile::ArgParser args)
1e-6
);
1e-6
);
rtn
&=
ck_tile
::
check_err
(
rtn
&=
ck_tile
::
check_err
(
expert_ids_host
,
expert_ids_ref
,
std
::
string
(
"OUT Error: Incorrect eid!"
),
1e-6
,
1e-6
);
expert_ids_host
,
expert_ids_ref
,
std
::
string
(
"OUT Error: Incorrect eid!"
),
1e-6
,
1e-6
);
rtn
&=
ck_tile
::
check_err
(
moe_buf_host
,
moe_buf_ref
,
std
::
string
(
"OUT Error: Incorrect zero buf!"
),
0
,
0
);
rtn
&=
total_tokens_post_pad
==
sorted_id_cnt_host
.
mData
[
0
];
rtn
&=
total_tokens_post_pad
==
sorted_id_cnt_host
.
mData
[
0
];
}
}
...
...
example/ck_tile/13_moe_sorting/moe_sorting_api.cpp
View file @
bfe0120a
...
@@ -5,13 +5,14 @@
...
@@ -5,13 +5,14 @@
#define MOE_SORTING_DISPATCH(unroll_num_) \
#define MOE_SORTING_DISPATCH(unroll_num_) \
constexpr ck_tile::index_t unroll_num = unroll_num_; \
constexpr ck_tile::index_t unroll_num = unroll_num_; \
using ms_problem = ck_tile::MoeSortingProblem<index_t, ms_weight_type, unroll_num>; \
using ms_problem = ck_tile::MoeSortingProblem<index_t, ms_weight_type, unroll_num>; \
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
float ave_time = \
const auto lds_bytes = kernel::GetSmemSize(a); \
ck_tile::launch_kernel(s, ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs)); \
float ave_time = ck_tile::launch_kernel( \
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
return ave_time;
return ave_time;
float
moe_sorting
(
moe_sorting_trait
t
,
moe_sorting_args
a
,
ck_tile
::
stream_config
s
)
float
moe_sorting
(
moe_sorting_trait
t
,
moe_sorting_args
a
,
ck_tile
::
stream_config
s
)
...
...
include/ck_tile/ops/moe_sorting/kernel/moe_sorting_kernel.hpp
View file @
bfe0120a
...
@@ -20,10 +20,12 @@ struct MoeSortingHostArgs
...
@@ -20,10 +20,12 @@ struct MoeSortingHostArgs
void
*
sorted_weights
;
void
*
sorted_weights
;
void
*
expert_ids
;
void
*
expert_ids
;
void
*
total_tokens_post_pad
;
void
*
total_tokens_post_pad
;
void
*
moe_buf
;
index_t
tokens
;
index_t
tokens
;
index_t
unit_size
;
index_t
unit_size
;
index_t
num_experts
;
index_t
num_experts
;
index_t
topk
;
index_t
topk
;
index_t
moe_buf_set_bytes
;
};
};
template
<
typename
Problem_
>
template
<
typename
Problem_
>
...
@@ -46,33 +48,32 @@ struct MoeSortingKernel
...
@@ -46,33 +48,32 @@ struct MoeSortingKernel
void
*
sorted_weights
;
void
*
sorted_weights
;
void
*
expert_ids
;
void
*
expert_ids
;
void
*
total_tokens_post_pad
;
void
*
total_tokens_post_pad
;
void
*
moe_buf
;
index_t
tokens
;
index_t
tokens
;
index_t
num_experts
;
index_t
num_experts
;
index_t
moe_buf_set_bytes
;
index_t
tokens_per_thread
;
index_t
tokens_per_thread
;
mdiv
unit_size_mdiv
;
mdiv
unit_size_mdiv
;
mdiv
topk_mdiv
;
mdiv
topk_mdiv
;
};
};
CK_TILE_HOST
static
constexpr
auto
GridSize
(
const
Hargs
&
)
CK_TILE_HOST
static
constexpr
auto
GridSize
(
const
Hargs
&
h
)
{
{
// TODO: assume num-experts not too much
// TODO: assume num-experts not too much
return
dim3
(
1
);
return
dim3
(
1
+
ck_tile
::
integer_divide_ceil
(
h
.
moe_buf_set_bytes
,
BlockSize
(
h
).
x
*
16
)
);
}
}
CK_TILE_HOST
static
constexpr
auto
BlockSize
(
const
Hargs
&
h
)
CK_TILE_HOST
static
constexpr
auto
BlockSize
(
const
Hargs
&
h
)
{
{
// TODO: need pad to multiply of warp size
return
dim3
(
ck_tile
::
integer_least_multiple
(
h
.
num_experts
,
ck_tile
::
get_warp_size
()));
return
dim3
(
ck_tile
::
integer_least_multiple
(
h
.
num_experts
,
ck_tile
::
get_warp_size
()));
}
}
// in byte
// in byte
CK_TILE_
DEVICE
static
constexpr
index_t
GetSmemSize
()
CK_TILE_
HOST
static
constexpr
auto
GetSmemSize
(
const
Hargs
&
h
)
{
{
// const auto blocks = BlockSize(h);
const
auto
blocks
=
BlockSize
(
h
);
// return ((blockDim.x + 1) * k.num_experts + (k.num_experts + 1)) * sizeof(index_t);
return
((
blocks
.
x
+
1
)
*
h
.
num_experts
+
(
h
.
num_experts
+
1
))
*
sizeof
(
index_t
);
// TODO: can not use dynamic calculation. need use static to guide compiler
return
65536
;
}
}
CK_TILE_HOST
static
constexpr
auto
MakeKargs
(
const
Hargs
&
h
)
CK_TILE_HOST
static
constexpr
auto
MakeKargs
(
const
Hargs
&
h
)
...
@@ -83,9 +84,11 @@ struct MoeSortingKernel
...
@@ -83,9 +84,11 @@ struct MoeSortingKernel
k
.
sorted_token_ids
=
h
.
sorted_token_ids
;
k
.
sorted_token_ids
=
h
.
sorted_token_ids
;
k
.
sorted_weights
=
h
.
sorted_weights
;
k
.
sorted_weights
=
h
.
sorted_weights
;
k
.
expert_ids
=
h
.
expert_ids
;
k
.
expert_ids
=
h
.
expert_ids
;
k
.
moe_buf
=
h
.
moe_buf
;
k
.
total_tokens_post_pad
=
h
.
total_tokens_post_pad
;
k
.
total_tokens_post_pad
=
h
.
total_tokens_post_pad
;
k
.
tokens
=
h
.
tokens
;
k
.
tokens
=
h
.
tokens
;
k
.
num_experts
=
h
.
num_experts
;
k
.
num_experts
=
h
.
num_experts
;
k
.
moe_buf_set_bytes
=
h
.
moe_buf_set_bytes
;
const
auto
blocks
=
BlockSize
(
h
);
const
auto
blocks
=
BlockSize
(
h
);
k
.
tokens_per_thread
=
integer_divide_ceil
(
h
.
tokens
*
h
.
topk
,
blocks
.
x
);
k
.
tokens_per_thread
=
integer_divide_ceil
(
h
.
tokens
*
h
.
topk
,
blocks
.
x
);
...
@@ -99,6 +102,15 @@ struct MoeSortingKernel
...
@@ -99,6 +102,15 @@ struct MoeSortingKernel
return
row
*
total_col
+
col
;
return
row
*
total_col
+
col
;
}
}
CK_TILE_DEVICE
void
moe_buf_set_zero_kernel
(
uint8x16_t
*
buf
,
index_t
buf_bytes
)
const
{
const
index_t
offset
=
(
blockIdx
.
x
-
1
)
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
offset
<
buf_bytes
/
16
)
{
buf
[
offset
]
=
uint8x16_t
(
0
);
}
}
CK_TILE_DEVICE
void
moe_align_block_size_kernel
(
const
IndexType
*
__restrict__
topk_id
,
CK_TILE_DEVICE
void
moe_align_block_size_kernel
(
const
IndexType
*
__restrict__
topk_id
,
const
WeightType
*
__restrict__
weights
,
const
WeightType
*
__restrict__
weights
,
index_t
*
sorted_token_ids
,
index_t
*
sorted_token_ids
,
...
@@ -192,8 +204,13 @@ struct MoeSortingKernel
...
@@ -192,8 +204,13 @@ struct MoeSortingKernel
CK_TILE_DEVICE
void
operator
()(
Kargs
kargs
)
const
CK_TILE_DEVICE
void
operator
()(
Kargs
kargs
)
const
{
{
if
(
blockIdx
.
x
>
0
)
{
return
moe_buf_set_zero_kernel
(
reinterpret_cast
<
uint8x16_t
*>
(
kargs
.
moe_buf
),
kargs
.
moe_buf_set_bytes
);
}
const
size_t
numel
=
kargs
.
tokens
*
kargs
.
topk_mdiv
.
divisor
;
const
size_t
numel
=
kargs
.
tokens
*
kargs
.
topk_mdiv
.
divisor
;
__shared__
char
smem
[
GetSmemSize
()
];
extern
__shared__
char
smem
[];
return
moe_align_block_size_kernel
(
static_cast
<
const
IndexType
*>
(
kargs
.
p_topk_ids
),
return
moe_align_block_size_kernel
(
static_cast
<
const
IndexType
*>
(
kargs
.
p_topk_ids
),
static_cast
<
const
WeightType
*>
(
kargs
.
p_weights
),
static_cast
<
const
WeightType
*>
(
kargs
.
p_weights
),
static_cast
<
IndexType
*>
(
kargs
.
sorted_token_ids
),
static_cast
<
IndexType
*>
(
kargs
.
sorted_token_ids
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment