Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
e15c6f2d
Commit
e15c6f2d
authored
Jan 15, 2025
by
coderfeli
Browse files
skip out of bound rowid
parent
20ac5ef9
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
19 additions
and
8 deletions
+19
-8
example/ck_tile/15_fused_moe/main.cpp
example/ck_tile/15_fused_moe/main.cpp
+4
-4
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp
...ude/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp
+2
-2
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp
...s/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp
+13
-2
No files found.
example/ck_tile/15_fused_moe/main.cpp
View file @
e15c6f2d
...
@@ -346,10 +346,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -346,10 +346,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
}
else
else
{
{
//
for(int i = 0; i < static_cast<int>(topk_ids_host.mData.size()); i++) {
for
(
int
i
=
0
;
i
<
static_cast
<
int
>
(
topk_ids_host
.
mData
.
size
());
i
++
)
{
//
topk_ids_host.mData[i] =
0
;
topk_ids_host
.
mData
[
i
]
=
i
%
4
;
//
}
}
topid_unique_gen
<
IndexDataType
>
(
topk_ids_host
.
mData
,
tokens
,
topk
,
experts
,
11913
);
//
topid_unique_gen<IndexDataType>(topk_ids_host.mData, tokens, topk, experts, 11913);
}
}
// leave it here for future debug purpose
// leave it here for future debug purpose
...
...
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp
View file @
e15c6f2d
...
@@ -240,7 +240,7 @@ struct FusedMoeGemmKernel
...
@@ -240,7 +240,7 @@ struct FusedMoeGemmKernel
{
{
if
constexpr
(
UseUK
)
if
constexpr
(
UseUK
)
{
{
__shared__
CK_TILE_LDS_ADDR
ADataType
smem
[
GetSmemSize
()];
__shared__
CK_TILE_LDS_ADDR
char
smem
[
GetSmemSize
()];
IndexDataType
num_sorted_tiles
=
__builtin_amdgcn_readfirstlane
(
IndexDataType
num_sorted_tiles
=
__builtin_amdgcn_readfirstlane
(
*
reinterpret_cast
<
const
IndexDataType
*>
(
kargs
.
num_sorted_tiles_ptr
));
*
reinterpret_cast
<
const
IndexDataType
*>
(
kargs
.
num_sorted_tiles_ptr
));
...
@@ -275,7 +275,7 @@ struct FusedMoeGemmKernel
...
@@ -275,7 +275,7 @@ struct FusedMoeGemmKernel
index_t
expert_stride_0
=
kargs
.
intermediate_size
*
hidden_radio_0
*
kargs
.
hidden_size
;
index_t
expert_stride_0
=
kargs
.
intermediate_size
*
hidden_radio_0
*
kargs
.
hidden_size
;
index_t
expert_stride_1
=
kargs
.
intermediate_size
*
kargs
.
hidden_size
;
index_t
expert_stride_1
=
kargs
.
intermediate_size
*
kargs
.
hidden_size
;
__shared__
CK_TILE_LDS_ADDR
ADataType
smem
[
GetSmemSize
()];
__shared__
CK_TILE_LDS_ADDR
char
smem
[
GetSmemSize
()];
// note this is in unit of tile, need multiple tile size to get the index
// note this is in unit of tile, need multiple tile size to get the index
const
auto
[
sorted_tile_id
,
intermediate_tile_id
]
=
const
auto
[
sorted_tile_id
,
intermediate_tile_id
]
=
...
...
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp
View file @
e15c6f2d
...
@@ -70,13 +70,21 @@ struct FusedMoeGemmPipeline_FlatmmUk
...
@@ -70,13 +70,21 @@ struct FusedMoeGemmPipeline_FlatmmUk
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
{
{
return
32768
;
// constexpr index_t smem_0 = Policy::template GetUK_0<Problem>().GetSmemSize();
// constexpr index_t smem_0 = Policy::template GetUK_0<Problem>().GetSmemSize();
// constexpr index_t smem_1 = Policy::template GetUK_1<Problem>().GetSmemSize();
// constexpr index_t smem_1 = Policy::template GetUK_1<Problem>().GetSmemSize();
// constexpr index_t smem_bridge =
// constexpr index_t smem_bridge =
// BlockShape::Block_M0 * BlockShape::Block_N0;
// BlockShape::Block_M0 * BlockShape::Block_N0;
return
32768
;
//max(smem_0, max(smem_1, smem_bridge));
// return 32768;//max(smem_0, max(smem_1, smem_bridge));
constexpr
index_t
smem_0
=
Policy
::
template
GetUK_0
<
Problem
>().
GetSmemSize
();
constexpr
index_t
smem_1
=
Policy
::
template
GetUK_1
<
Problem
>().
GetSmemSize
();
constexpr
index_t
smem_bridge
=
BlockShape
::
Block_M0
*
BlockShape
::
Block_N0
*
sizeof
(
YDataType
);
// return max(smem_0, max(smem_1, smem_bridge));
return
max
(
smem_0
+
smem_1
,
smem_bridge
);
}
}
// this is the thread-offset along row/col
// this is the thread-offset along row/col
CK_TILE_HOST_DEVICE
static
auto
GetACoord
()
CK_TILE_HOST_DEVICE
static
auto
GetACoord
()
{
{
...
@@ -199,7 +207,10 @@ struct FusedMoeGemmPipeline_FlatmmUk
...
@@ -199,7 +207,10 @@ struct FusedMoeGemmPipeline_FlatmmUk
threadIdx
.
x
%
(
BlockShape
::
Block_K0
/
kAlignmentA
)
*
kAlignmentA
;
threadIdx
.
x
%
(
BlockShape
::
Block_K0
/
kAlignmentA
)
*
kAlignmentA
;
},
},
number
<
row_ids_a
.
size
()
>
{});
number
<
row_ids_a
.
size
()
>
{});
// if(threadIdx.x==0)
// printf("row id %d\n", row_ids_a[0]);
if
(
row_ids_a
.
at
(
0
)
>=
kargs
.
num_tokens
)
return
;
auto
a_res
=
auto
a_res
=
make_wave_buffer_resource
(
reinterpret_cast
<
const
ADataType
*>
(
kargs
.
a_ptr
),
make_wave_buffer_resource
(
reinterpret_cast
<
const
ADataType
*>
(
kargs
.
a_ptr
),
kargs
.
num_tokens
*
kargs
.
stride_token
*
sizeof
(
ADataType
));
kargs
.
num_tokens
*
kargs
.
stride_token
*
sizeof
(
ADataType
));
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment