Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
f912ca40
Commit
f912ca40
authored
Nov 21, 2024
by
"letaoqin"
Browse files
fix call indexing adaptor issue
parent
1561fc22
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
26 additions
and
22 deletions
+26
-22
include/ck_tile/core/algorithm/indexing_adaptor.hpp
include/ck_tile/core/algorithm/indexing_adaptor.hpp
+15
-12
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_general_kernel.hpp
...ile/ops/fused_moe/kernel/fused_moegemm_general_kernel.hpp
+2
-2
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_gl.hpp
...s/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_gl.hpp
+9
-8
No files found.
include/ck_tile/core/algorithm/indexing_adaptor.hpp
View file @
f912ca40
...
...
@@ -65,8 +65,8 @@ struct indexing_adaptor
CK_TILE_HOST_DEVICE
constexpr
indexing_adaptor
()
=
default
;
CK_TILE_HOST_DEVICE
constexpr
indexing_adaptor
(
const
IndexingType
*
idx
)
:
cached_idx_
(
idx
)
{}
const
IndexingType
*
cached_idx_
;
mutable
index_t
pre
UpI
ndex
=
0
;
mutable
index_t
pre
LowI
ndex
=
0
;
mutable
index_t
pre
_up_i
ndex
_
=
0
;
mutable
index_t
pre
_low_i
ndex
_
=
0
;
template
<
typename
LowIdx
,
typename
UpIdx
>
CK_TILE_HOST_DEVICE
constexpr
void
calculate_lower_index
(
LowIdx
&
idx_low
,
...
...
@@ -77,12 +77,14 @@ struct indexing_adaptor
idx_low
(
number
<
0
>
{})
=
*
(
cached_idx_
+
idx_up
[
number
<
0
>
{}]);
preUpIndex
=
idx_up
[
number
<
0
>
{}];
preLowIndex
=
idx_low
(
number
<
0
>
{});
if
(
threadIdx
.
x
==
0
&&
blockIdx
.
x
==
0
&&
blockIdx
.
y
==
0
&&
blockIdx
.
z
==
0
)
pre_up_index_
=
idx_up
[
number
<
0
>
{}];
pre_low_index_
=
idx_low
(
number
<
0
>
{});
#if 0
if(threadIdx.x == 65 && blockIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0)
{
printf("\n first index from %d to %d \n", idx_up[number<0>{}], idx_low(number<0>{}));
}
#endif
}
template
<
typename
LowIdxDiff
,
typename
UpIdxDiff
,
typename
LowIdx
,
typename
UpIdx
>
...
...
@@ -96,14 +98,14 @@ struct indexing_adaptor
UpIdx
::
size
()
==
1
,
"wrong! inconsistent # of dimension"
);
int
up_index
=
idx_diff_up
[
number
<
0
>
{}]
+
pre
UpI
ndex
;
int
up_index
=
idx_diff_up
[
number
<
0
>
{}]
+
pre
_up_i
ndex
_
;
int
low_index
=
*
(
cached_idx_
+
up_index
);
idx_diff_low
(
number
<
0
>
{})
=
low_index
-
pre
LowI
ndex
;
idx_diff_low
(
number
<
0
>
{})
=
low_index
-
pre
_low_i
ndex
_
;
pre
UpI
ndex
=
up_index
;
pre
LowI
ndex
=
low_index
;
if
(
threadIdx
.
x
==
0
&&
blockIdx
.
x
==
0
&&
blockIdx
.
y
==
0
&&
blockIdx
.
z
==
0
)
pre
_up_i
ndex
_
=
up_index
;
pre
_low_i
ndex
_
=
low_index
;
#if 0
if(threadIdx.x ==
65
&& blockIdx.x == 0 && blockIdx.y ==
1
&& blockIdx.z == 0)
{
printf("\n index form %d to %d, diff from %d to %d \n",
up_index,
...
...
@@ -111,6 +113,7 @@ struct indexing_adaptor
idx_diff_up[number<0>{}],
idx_diff_low(number<0>{}));
}
#endif
// pass the diff to lower, but not changing the actually index
}
...
...
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_general_kernel.hpp
View file @
f912ca40
...
...
@@ -268,8 +268,8 @@ struct FusedMoeGemmGlKernel
auto
topk_weight
=
reinterpret_cast
<
const
TopkWeightDataType
*>
(
kargs
.
sorted_weight_ptr
)[
sorted_token_id
];
const
index_t
*
sorted_token_ids_ptr
=
reinterpret_cast
<
const
index_t
*>
(
&
(
reinterpret_cast
<
const
index_t
*>
(
kargs
.
sorted_token_ids_ptr
)
[
sorted_token_id
]))
;
const
index_t
*
sorted_token_ids_ptr
=
reinterpret_cast
<
const
index_t
*>
(
kargs
.
sorted_token_ids_ptr
);
const
auto
a_window
=
[
&
]()
{
// A is already pre-padded in previous kernel
...
...
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_gl.hpp
View file @
f912ca40
...
...
@@ -104,19 +104,20 @@ struct FusedMoeGemmPipeline_FlatmmGl
ignore
=
hidden_size
;
ignore
=
intermediate_size
;
auto
a_copy_dram_window
=
make_tile_window
(
a_window_
.
get_bottom_tensor_view
(),
auto
a_copy_dram_window
=
make_tile_window
(
a_window_
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
BlockShape
::
Block_M0
>
{},
number
<
BlockShape
::
Block_K0
>
{}),
a_window_
.
get_window_origin
(),
Policy
::
template
MakeGlobalTileDistribution_A
<
Problem
>());
auto
a_dram
=
load_tile
(
a_copy_dram_window
);
#if 0
//check a matrix gather right or not
constexpr auto a_spans = decltype(a_dram)::get_distributed_spans();
int counter = 0;
sweep_tile_span(a_spans[number<0>{}], [&](auto idxm) {
sweep_tile_span(a_spans[number<1>{}], [&](auto idxk){
constexpr auto i_j_idx = make_tuple(idxm, idxk);
if
(
threadIdx
.
x
==
0
&&
blockIdx
.
x
==
0
&&
blockIdx
.
y
==
0
&&
blockIdx
.
z
==
0
){
if(threadIdx.x ==
65
&& blockIdx.x == 0 && blockIdx.y ==
1
&& blockIdx.z == 0){
counter = counter + 1;
index_t idm_0 = idxm.impl_.at(0);
index_t idn_0 = idxk.impl_.at(0);
...
...
@@ -124,8 +125,8 @@ struct FusedMoeGemmPipeline_FlatmmGl
}
});
});
ignore
=
a_
spans
;
#endif
ignore
=
a_
dram
;
}
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment