Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
0385cded
Commit
0385cded
authored
Nov 06, 2024
by
dummycoderfe
Browse files
[Ck_tile] moe set zero run ok, add size check and fix ref check
parent
bfe0120a
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
9 additions
and
4 deletions
+9
-4
example/ck_tile/13_moe_sorting/moe_sorting.cpp
example/ck_tile/13_moe_sorting/moe_sorting.cpp
+3
-3
example/ck_tile/13_moe_sorting/moe_sorting_api.cpp
example/ck_tile/13_moe_sorting/moe_sorting_api.cpp
+5
-0
include/ck_tile/ops/moe_sorting/kernel/moe_sorting_kernel.hpp
include/ck_tile/ops/moe_sorting/kernel/moe_sorting_kernel.hpp
+1
-1
No files found.
example/ck_tile/13_moe_sorting/moe_sorting.cpp
View file @
0385cded
...
@@ -96,7 +96,7 @@ bool test_moe_sorting(ck_tile::ArgParser args)
...
@@ -96,7 +96,7 @@ bool test_moe_sorting(ck_tile::ArgParser args)
ck_tile
::
HostTensor
<
WeightType
>
sorted_weights_host
({
max_output_ids
},
{
1
});
ck_tile
::
HostTensor
<
WeightType
>
sorted_weights_host
({
max_output_ids
},
{
1
});
ck_tile
::
HostTensor
<
IndexType
>
expert_ids_host
({
max_output_ids
/
unit_size
},
{
1
});
ck_tile
::
HostTensor
<
IndexType
>
expert_ids_host
({
max_output_ids
/
unit_size
},
{
1
});
ck_tile
::
HostTensor
<
IndexType
>
sorted_id_cnt_host
({
1
},
{
1
});
ck_tile
::
HostTensor
<
IndexType
>
sorted_id_cnt_host
({
1
},
{
1
});
ck_tile
::
HostTensor
<
IndexType
>
moe_buf_host
({
moe_buf_size
},
{
1
});
ck_tile
::
HostTensor
<
float
>
moe_buf_host
({
moe_buf_size
},
{
1
});
ck_tile
::
FillUniformDistribution
<
WeightType
>
{
-
.5
f
,
.5
f
}(
weights_host
);
ck_tile
::
FillUniformDistribution
<
WeightType
>
{
-
.5
f
,
.5
f
}(
weights_host
);
ck_tile
::
FillUniformDistribution
<
WeightType
>
{
-
.5
f
,
.5
f
}(
moe_buf_host
);
ck_tile
::
FillUniformDistribution
<
WeightType
>
{
-
.5
f
,
.5
f
}(
moe_buf_host
);
...
@@ -127,7 +127,7 @@ bool test_moe_sorting(ck_tile::ArgParser args)
...
@@ -127,7 +127,7 @@ bool test_moe_sorting(ck_tile::ArgParser args)
unit_size
,
unit_size
,
experts
,
experts
,
topk
,
topk
,
moe_buf_size
};
static_cast
<
ck_tile
::
index_t
>
(
moe_buf_host
.
get_element_space_size_in_bytes
())
};
ck_tile
::
stream_config
sc
{
nullptr
,
ck_tile
::
stream_config
sc
{
nullptr
,
true
,
true
,
...
@@ -162,7 +162,7 @@ bool test_moe_sorting(ck_tile::ArgParser args)
...
@@ -162,7 +162,7 @@ bool test_moe_sorting(ck_tile::ArgParser args)
ck_tile
::
HostTensor
<
IndexType
>
sorted_ids_ref
({
max_output_ids
},
{
1
});
ck_tile
::
HostTensor
<
IndexType
>
sorted_ids_ref
({
max_output_ids
},
{
1
});
ck_tile
::
HostTensor
<
WeightType
>
sorted_weights_ref
({
max_output_ids
},
{
1
});
ck_tile
::
HostTensor
<
WeightType
>
sorted_weights_ref
({
max_output_ids
},
{
1
});
ck_tile
::
HostTensor
<
IndexType
>
expert_ids_ref
({
max_output_ids
/
unit_size
},
{
1
});
ck_tile
::
HostTensor
<
IndexType
>
expert_ids_ref
({
max_output_ids
/
unit_size
},
{
1
});
ck_tile
::
HostTensor
<
IndexType
>
moe_buf_ref
({
moe_buf_size
},
{
1
});
ck_tile
::
HostTensor
<
WeightType
>
moe_buf_ref
({
moe_buf_size
},
{
1
});
moe_buf_ref
.
SetZero
();
moe_buf_ref
.
SetZero
();
int32_t
total_tokens_post_pad
=
0
;
int32_t
total_tokens_post_pad
=
0
;
...
...
example/ck_tile/13_moe_sorting/moe_sorting_api.cpp
View file @
0385cded
...
@@ -24,6 +24,11 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
...
@@ -24,6 +24,11 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
printf
(
"lds size exceed, only support experts <127
\n
"
);
printf
(
"lds size exceed, only support experts <127
\n
"
);
return
-
1
;
return
-
1
;
}
}
if
(
a
.
moe_buf_set_bytes
%
16
)
{
printf
(
"buf set size %d unaligned, must be multiple of 16
\n
"
,
a
.
moe_buf_set_bytes
);
return
-
1
;
}
using
index_t
=
ck_tile
::
index_t
;
using
index_t
=
ck_tile
::
index_t
;
using
ms_weight_type
=
float
;
using
ms_weight_type
=
float
;
index_t
smem_io_unroll_num
=
ck_tile
::
integer_divide_ceil
(
a
.
tokens
*
a
.
topk
,
64
);
index_t
smem_io_unroll_num
=
ck_tile
::
integer_divide_ceil
(
a
.
tokens
*
a
.
topk
,
64
);
...
...
include/ck_tile/ops/moe_sorting/kernel/moe_sorting_kernel.hpp
View file @
0385cded
...
@@ -107,7 +107,7 @@ struct MoeSortingKernel
...
@@ -107,7 +107,7 @@ struct MoeSortingKernel
const
index_t
offset
=
(
blockIdx
.
x
-
1
)
*
blockDim
.
x
+
threadIdx
.
x
;
const
index_t
offset
=
(
blockIdx
.
x
-
1
)
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
offset
<
buf_bytes
/
16
)
if
(
offset
<
buf_bytes
/
16
)
{
{
buf
[
offset
]
=
uint8x16_t
(
0
)
;
buf
[
offset
]
=
uint8x16_t
{
0
}
;
}
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment