Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
9964919d
Commit
9964919d
authored
Nov 01, 2024
by
dummycoderfe
Browse files
fix comments & typo
parent
2bf0057a
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
67 additions
and
72 deletions
+67
-72
example/ck_tile/12_moe_sorting/moe_sorting.cpp
example/ck_tile/12_moe_sorting/moe_sorting.cpp
+9
-16
example/ck_tile/12_moe_sorting/moe_sorting_api.cpp
example/ck_tile/12_moe_sorting/moe_sorting_api.cpp
+2
-2
include/ck_tile/host/reference/reference_moe_sorting.hpp
include/ck_tile/host/reference/reference_moe_sorting.hpp
+26
-26
include/ck_tile/ops/moe_sorting/kernel/moe_sorting_kernel.hpp
...ude/ck_tile/ops/moe_sorting/kernel/moe_sorting_kernel.hpp
+9
-7
include/ck_tile/ops/moe_sorting/pipeline/moe_sorting_pipeline.hpp
...ck_tile/ops/moe_sorting/pipeline/moe_sorting_pipeline.hpp
+21
-21
No files found.
example/ck_tile/12_moe_sorting/moe_sorting.cpp
View file @
9964919d
...
@@ -158,34 +158,27 @@ bool test_moe_sorting(ck_tile::ArgParser args)
...
@@ -158,34 +158,27 @@ bool test_moe_sorting(ck_tile::ArgParser args)
bool
rtn
=
true
;
bool
rtn
=
true
;
if
(
validate
)
if
(
validate
)
{
{
ck_tile
::
HostTensor
<
IndexType
>
sorted_ids_ref
({
max_output_ids
},
{
1
});
ck_tile
::
HostTensor
<
IndexType
>
sorted_ids_ref
({
max_output_ids
},
{
1
});
ck_tile
::
HostTensor
<
WeightType
>
sorted_weights_ref
({
max_output_ids
},
{
1
});
ck_tile
::
HostTensor
<
WeightType
>
sorted_weights_ref
({
max_output_ids
},
{
1
});
ck_tile
::
HostTensor
<
IndexType
>
expert_ids_ref
({
max_output_ids
/
unit_size
},
{
1
});
ck_tile
::
HostTensor
<
IndexType
>
expert_ids_ref
({
max_output_ids
/
unit_size
},
{
1
});
int32_t
total_tokens_post_pad
=
0
;
int32_t
total_tokens_post_pad
=
0
;
ck_tile
::
reference_moe_sorting
<
WeightType
,
IndexType
>
(
sorted_ids_ref
.
data
(),
ck_tile
::
reference_moe_sorting
<
WeightType
,
IndexType
>
(
topk_ids_host
,
sorted_weights_ref
.
data
(),
weights_host
,
expert_ids_ref
.
data
(),
sorted_ids_ref
,
sorted_weights_ref
,
expert_ids_ref
,
total_tokens_post_pad
,
total_tokens_post_pad
,
weights_host
.
data
(),
topk_ids_host
.
data
(),
topk_ids_host
.
size
()
/
topk
,
experts
,
experts
,
topk
,
unit_size
);
unit_size
);
float
atol
=
1e-6
;
float
rtol
=
1e-6
;
rtn
&=
ck_tile
::
check_err
(
rtn
&=
ck_tile
::
check_err
(
sorted_ids_host
,
sorted_ids_ref
,
std
::
string
(
"OUT Error: Incorrect ids!"
),
rtol
,
atol
);
sorted_ids_host
,
sorted_ids_ref
,
std
::
string
(
"OUT Error: Incorrect ids!"
),
1e-6
,
1e-6
);
rtn
&=
ck_tile
::
check_err
(
sorted_weights_host
,
rtn
&=
ck_tile
::
check_err
(
sorted_weights_host
,
sorted_weights_ref
,
sorted_weights_ref
,
std
::
string
(
"OUT Error: Incorrect w!"
),
std
::
string
(
"OUT Error: Incorrect w!"
),
rtol
,
1e-6
,
atol
);
1e-6
);
rtn
&=
ck_tile
::
check_err
(
rtn
&=
ck_tile
::
check_err
(
expert_ids_host
,
expert_ids_ref
,
std
::
string
(
"OUT Error: Incorrect eid!"
),
rtol
,
atol
);
expert_ids_host
,
expert_ids_ref
,
std
::
string
(
"OUT Error: Incorrect eid!"
),
1e-6
,
1e-6
);
rtn
&=
total_tokens_post_pad
==
sorted_id_cnt_host
.
mData
[
0
];
rtn
&=
total_tokens_post_pad
==
sorted_id_cnt_host
.
mData
[
0
];
}
}
...
...
example/ck_tile/12_moe_sorting/moe_sorting_api.cpp
View file @
9964919d
...
@@ -10,8 +10,8 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_kargs a, ck_tile::stream_conf
...
@@ -10,8 +10,8 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_kargs a, ck_tile::stream_conf
using
index_t
=
ck_tile
::
index_t
;
using
index_t
=
ck_tile
::
index_t
;
using
ms_weight_type
=
float
;
using
ms_weight_type
=
float
;
using
ms_problem
=
ck_tile
::
MoeSortingProblem
<
index_t
,
ms_weight_type
>
;
using
ms_problem
=
ck_tile
::
MoeSortingProblem
<
index_t
,
ms_weight_type
>
;
using
ms_pipeline
=
ck_tile
::
MoeSortingPipeline
<
ms_problem
>
;
//
using ms_pipeline = ck_tile::MoeSortingPipeline<ms_problem>;
using
kernel
=
ck_tile
::
MoeSortingKernel
<
ms_p
ipeline
>
;
using
kernel
=
ck_tile
::
MoeSortingKernel
<
ms_p
roblem
>
;
auto
kargs
=
kernel
::
MakeKargs
(
a
);
auto
kargs
=
kernel
::
MakeKargs
(
a
);
const
dim3
grids
=
1
;
const
dim3
grids
=
1
;
const
dim3
blocks
=
ck_tile
::
max
(
t
.
experts
,
ck_tile
::
get_warp_size
());
const
dim3
blocks
=
ck_tile
::
max
(
t
.
experts
,
ck_tile
::
get_warp_size
());
...
...
include/ck_tile/host/reference/reference_moe_sorting.hpp
View file @
9964919d
...
@@ -9,21 +9,21 @@
...
@@ -9,21 +9,21 @@
namespace
ck_tile
{
namespace
ck_tile
{
template
<
typename
WeightType
,
typename
IndexType
=
index_t
>
template
<
typename
WeightType
,
typename
IndexType
=
index_t
>
CK_TILE_HOST
void
reference_moe_sorting
(
IndexType
*
sorted_token_ids_ptr
,
CK_TILE_HOST
void
reference_moe_sorting
(
const
HostTensor
<
IndexType
>&
topk_ids
,
WeightType
*
sorted_weight_buf
,
const
HostTensor
<
WeightType
>&
weights
,
IndexType
*
sorted_expert_ids_ptr
,
HostTensor
<
IndexType
>&
sorted_token_ids
,
index_t
&
sub_x_cnt
,
HostTensor
<
WeightType
>&
sorted_weight
,
const
WeightType
*
weights_ptr
,
HostTensor
<
IndexType
>&
sorted_expert_ids
,
const
IndexType
*
topk_ids_ptr
,
index_t
&
unit_cnt
,
const
index_t
num_token
,
const
index_t
experts
,
const
index_t
experts
,
const
index_t
topk
,
const
index_t
unit_size
)
const
index_t
sub_x
)
{
{
const
index_t
num_token
=
topk_ids
.
mDesc
.
get_lengths
()[
0
];
const
index_t
topk
=
topk_ids
.
mDesc
.
get_lengths
()[
1
];
std
::
vector
<
std
::
vector
<
IndexType
>>
expert_tokens
(
experts
,
std
::
vector
<
std
::
vector
<
IndexType
>>
expert_tokens
(
experts
,
std
::
vector
<
IndexType
>
(
sub_x
,
num_token
));
std
::
vector
<
IndexType
>
(
unit_size
,
num_token
));
std
::
vector
<
std
::
vector
<
WeightType
>>
expert_token_weights
(
experts
,
std
::
vector
<
std
::
vector
<
WeightType
>>
expert_token_weights
(
experts
,
std
::
vector
<
WeightType
>
(
sub_x
,
0
));
std
::
vector
<
WeightType
>
(
unit_size
,
0
));
std
::
vector
<
IndexType
>
expert_slices
(
experts
,
1
);
std
::
vector
<
IndexType
>
expert_slices
(
experts
,
1
);
std
::
vector
<
IndexType
>
expert_slice_idxs
(
experts
,
0
);
std
::
vector
<
IndexType
>
expert_slice_idxs
(
experts
,
0
);
...
@@ -31,16 +31,16 @@ CK_TILE_HOST void reference_moe_sorting(IndexType* sorted_token_ids_ptr,
...
@@ -31,16 +31,16 @@ CK_TILE_HOST void reference_moe_sorting(IndexType* sorted_token_ids_ptr,
{
{
for
(
index_t
k
=
0
;
k
<
topk
;
k
++
)
for
(
index_t
k
=
0
;
k
<
topk
;
k
++
)
{
{
i
ndex
_t
e
=
*
(
topk_ids
_ptr
+
t
*
topk
+
k
);
I
ndex
Type
e
=
topk_ids
(
t
,
k
);
WeightType
w
=
*
(
weights
_ptr
+
t
*
topk
+
k
);
WeightType
w
=
weights
(
t
,
k
);
index_t
idx
=
expert_slice_idxs
[
e
];
index_t
idx
=
expert_slice_idxs
[
e
];
if
(
idx
>
expert_slices
[
e
]
*
sub_x
-
1
)
if
(
idx
>
expert_slices
[
e
]
*
unit_size
-
1
)
{
{
expert_slices
[
e
]
++
;
expert_slices
[
e
]
++
;
index_t
new_size
=
expert_slices
[
e
]
*
sub_x
;
index_t
new_size
=
expert_slices
[
e
]
*
unit_size
;
expert_tokens
[
e
].
resize
(
new_size
);
expert_tokens
[
e
].
resize
(
new_size
);
expert_token_weights
[
e
].
resize
(
new_size
);
expert_token_weights
[
e
].
resize
(
new_size
);
for
(
index_t
idx
=
(
expert_slices
[
e
]
-
1
)
*
sub_x
;
idx
<
new_size
;
idx
++
)
for
(
index_t
idx
=
(
expert_slices
[
e
]
-
1
)
*
unit_size
;
idx
<
new_size
;
idx
++
)
{
{
expert_tokens
[
e
][
idx
]
=
num_token
;
expert_tokens
[
e
][
idx
]
=
num_token
;
expert_token_weights
[
e
][
idx
]
=
0
;
expert_token_weights
[
e
][
idx
]
=
0
;
...
@@ -53,23 +53,23 @@ CK_TILE_HOST void reference_moe_sorting(IndexType* sorted_token_ids_ptr,
...
@@ -53,23 +53,23 @@ CK_TILE_HOST void reference_moe_sorting(IndexType* sorted_token_ids_ptr,
}
}
}
}
IndexType
*
tokens
=
sorted_token_ids
_ptr
;
IndexType
*
out_
tokens
=
sorted_token_ids
.
data
()
;
WeightType
*
weights
=
sorted_weight
_buf
;
WeightType
*
out_
weights
=
sorted_weight
.
data
()
;
IndexType
*
erp
_id
s
=
sorted_expert_ids
_ptr
;
IndexType
*
out_expert
_id
=
sorted_expert_ids
.
data
()
;
for
(
index_t
e
=
0
;
e
<
experts
;
e
++
)
for
(
index_t
e
=
0
;
e
<
experts
;
e
++
)
{
{
memcpy
(
tokens
,
expert_tokens
[
e
].
data
(),
sizeof
(
index_t
)
*
expert_slices
[
e
]
*
sub_x
);
memcpy
(
out_
tokens
,
expert_tokens
[
e
].
data
(),
sizeof
(
index_t
)
*
expert_slices
[
e
]
*
unit_size
);
tokens
+=
expert_slices
[
e
]
*
sub_x
;
out_
tokens
+=
expert_slices
[
e
]
*
unit_size
;
memcpy
(
memcpy
(
weights
,
expert_token_weights
[
e
].
data
(),
sizeof
(
WeightType
)
*
expert_slices
[
e
]
*
sub_x
);
out_
weights
,
expert_token_weights
[
e
].
data
(),
sizeof
(
WeightType
)
*
expert_slices
[
e
]
*
unit_size
);
weights
+=
expert_slices
[
e
]
*
sub_x
;
out_
weights
+=
expert_slices
[
e
]
*
unit_size
;
for
(
index_t
s
=
0
;
s
<
expert_slices
[
e
];
s
++
)
for
(
index_t
s
=
0
;
s
<
expert_slices
[
e
];
s
++
)
{
{
erp
_id
s
[
s
]
=
e
;
out_expert
_id
[
s
]
=
e
;
sub_x
_cnt
++
;
unit
_cnt
++
;
}
}
erp
_id
s
+=
expert_slices
[
e
];
out_expert
_id
+=
expert_slices
[
e
];
}
}
return
;
return
;
...
...
include/ck_tile/ops/moe_sorting/kernel/moe_sorting_kernel.hpp
View file @
9964919d
...
@@ -26,11 +26,11 @@ struct MoeSortingHostArgs
...
@@ -26,11 +26,11 @@ struct MoeSortingHostArgs
index_t
topk
;
index_t
topk
;
};
};
template
<
typename
P
ipeline
_
>
template
<
typename
P
roblem
_
>
struct
MoeSortingKernel
struct
MoeSortingKernel
{
{
using
Pipeline
=
remove_cvref_t
<
Pipeline_
>
;
//
using Pipeline = remove_cvref_t<Pipeline_>;
using
Problem
=
remove_cvref_t
<
typename
Pipeline
::
Problem
>
;
using
Problem
=
remove_cvref_t
<
Problem
_
>
;
using
IndexType
=
typename
Problem
::
IndexType
;
using
IndexType
=
typename
Problem
::
IndexType
;
using
WeightType
=
typename
Problem
::
WeightType
;
using
WeightType
=
typename
Problem
::
WeightType
;
...
@@ -47,8 +47,6 @@ struct MoeSortingKernel
...
@@ -47,8 +47,6 @@ struct MoeSortingKernel
return
row
*
total_col
+
col
;
return
row
*
total_col
+
col
;
}
}
#define CEILDIV(x, y) (((x) + (y) - 1) / (y))
#define MAX(x, y) (((x) > (y)) ? (x) : (y))
CK_TILE_DEVICE
void
moe_align_block_size_kernel
(
const
IndexType
*
__restrict__
topk_id
,
CK_TILE_DEVICE
void
moe_align_block_size_kernel
(
const
IndexType
*
__restrict__
topk_id
,
const
WeightType
*
__restrict__
weights
,
const
WeightType
*
__restrict__
weights
,
index_t
*
sorted_token_ids
,
index_t
*
sorted_token_ids
,
...
@@ -60,7 +58,7 @@ struct MoeSortingKernel
...
@@ -60,7 +58,7 @@ struct MoeSortingKernel
const
size_t
numel
,
const
size_t
numel
,
const
index_t
topk
)
const
const
index_t
topk
)
const
{
{
const
size_t
tokens_per_thread
=
CEILDIV
(
numel
,
blockDim
.
x
);
const
size_t
tokens_per_thread
=
integer_divide_ceil
(
numel
,
blockDim
.
x
);
const
size_t
start_idx
=
threadIdx
.
x
*
tokens_per_thread
;
const
size_t
start_idx
=
threadIdx
.
x
*
tokens_per_thread
;
extern
__shared__
index_t
shared_mem
[];
extern
__shared__
index_t
shared_mem
[];
...
@@ -73,6 +71,10 @@ struct MoeSortingKernel
...
@@ -73,6 +71,10 @@ struct MoeSortingKernel
tokens_cnts
[
calc_index
(
num_experts
,
threadIdx
.
x
+
1
,
i
)]
=
0
;
tokens_cnts
[
calc_index
(
num_experts
,
threadIdx
.
x
+
1
,
i
)]
=
0
;
}
}
for
(
int
i
=
start_idx
;
i
<
numel
&&
i
<
start_idx
+
tokens_per_thread
;
++
i
)
{
++
tokens_cnts
[
calc_index
(
num_experts
,
threadIdx
.
x
+
1
,
topk_id
[
i
])];
}
__syncthreads
();
__syncthreads
();
if
(
threadIdx
.
x
<
num_experts
)
if
(
threadIdx
.
x
<
num_experts
)
...
@@ -93,7 +95,7 @@ struct MoeSortingKernel
...
@@ -93,7 +95,7 @@ struct MoeSortingKernel
{
{
cumsum
[
i
]
=
cumsum
[
i
]
=
cumsum
[
i
-
1
]
+
cumsum
[
i
-
1
]
+
MAX
(
CEILDIV
(
tokens_cnts
[
calc_index
(
num_experts
,
blockDim
.
x
,
i
-
1
)],
unit_size
),
max
(
integer_divide_ceil
(
tokens_cnts
[
calc_index
(
num_experts
,
blockDim
.
x
,
i
-
1
)],
unit_size
),
1
)
*
1
)
*
unit_size
;
unit_size
;
}
}
...
...
include/ck_tile/ops/moe_sorting/pipeline/moe_sorting_pipeline.hpp
View file @
9964919d
...
@@ -14,26 +14,26 @@
...
@@ -14,26 +14,26 @@
namespace
ck_tile
{
namespace
ck_tile
{
template
<
typename
Problem_
,
typename
Policy_
=
MoeSortingPolicy
>
//
template <typename Problem_, typename Policy_ = MoeSortingPolicy>
struct
MoeSortingPipeline
//
struct MoeSortingPipeline
{
//
{
// TODO: this kernel only support warp per row
//
// TODO: this kernel only support warp per row
using
Problem
=
remove_cvref_t
<
Problem_
>
;
//
using Problem = remove_cvref_t<Problem_>;
using
Policy
=
remove_cvref_t
<
Policy_
>
;
//
using Policy = remove_cvref_t<Policy_>;
using
WeightType
=
typename
Problem
::
WeightType
;
//
using WeightType = typename Problem::WeightType;
//
template <typename TopkIdWindow, typename WeightWindow>
//
template <typename TopkIdWindow, typename WeightWindow>
//
CK_TILE_DEVICE auto operator()(const TopkIdWindow& topk_id_window,
//
CK_TILE_DEVICE auto operator()(const TopkIdWindow& topk_id_window,
//
const WeightWindow& weight_window,
//
const WeightWindow& weight_window,
//
index_t* sorted_token_ids,
//
index_t* sorted_token_ids,
//
WeightType* sorted_weights,
//
WeightType* sorted_weights,
//
index_t* expert_ids,
//
index_t* expert_ids,
//
index_t* total_tokens_post_pad,
//
index_t* total_tokens_post_pad,
//
const index_t num_experts,
//
const index_t num_experts,
//
const index_t unit_size,
//
const index_t unit_size,
//
const size_t numel,
//
const size_t numel,
//
const index_t topk)
//
const index_t topk)
//
{
//
{
//
}
//
}
};
//
};
}
// namespace ck_tile
}
// namespace ck_tile
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment