Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
e1b457ec
Commit
e1b457ec
authored
Dec 23, 2024
by
letaoqin
Browse files
g and d add padding
parent
e97fdbc3
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
20 additions
and
21 deletions
+20
-21
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_general_kernel.hpp
...ile/ops/fused_moe/kernel/fused_moegemm_general_kernel.hpp
+12
-2
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp
...ude/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp
+1
-1
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp
...ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp
+7
-12
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp
...ed_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp
+0
-6
No files found.
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_general_kernel.hpp
View file @
e1b457ec
...
...
@@ -292,9 +292,14 @@ struct FusedMoeGemmGlKernel
number
<
Pipeline
::
kAlignmentG
>
{},
number
<
1
>
{});
const
auto
g_window_
=
make_tile_window
(
const
auto
g_view_1_
=
pad_tensor_view
(
g_view_
,
make_tuple
(
number
<
BlockShape
::
Block_N0
>
{},
number
<
BlockShape
::
Block_K0
>
{}),
sequence
<
PadIntermediateSize
,
PadHiddenSize
>
{});
const
auto
g_window_
=
make_tile_window
(
g_view_1_
,
make_tuple
(
number
<
BlockShape
::
Block_N0
>
{},
number
<
BlockShape
::
Block_K0
>
{}),
{
idx_n0
,
0
});
return
g_window_
;
...
...
@@ -328,9 +333,14 @@ struct FusedMoeGemmGlKernel
number
<
Pipeline
::
kAlignmentD
>
{},
number
<
1
>
{});
const
auto
d_window_
=
make_tile_window
(
const
auto
d_view_1_
=
pad_tensor_view
(
d_view_
,
make_tuple
(
number
<
BlockShape
::
Block_N1
>
{},
number
<
BlockShape
::
Block_K1
>
{}),
sequence
<
PadHiddenSize
,
PadIntermediateSize
>
{});
const
auto
d_window_
=
make_tile_window
(
d_view_1_
,
make_tuple
(
number
<
BlockShape
::
Block_N1
>
{},
number
<
BlockShape
::
Block_K1
>
{}),
{
0
,
idx_n0
});
return
d_window_
;
}();
...
...
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp
View file @
e1b457ec
...
...
@@ -391,7 +391,7 @@ struct FusedMoeGemmKernel
number
<
Pipeline
::
kAlignmentO
>
{},
number
<
1
>
{});
//
gather is here
//
scatter is here
auto
o_scatter_view_
=
transform_tensor_view
(
o_view_
,
make_tuple
(
make_indexing_transform
(
kargs
.
num_tokens
,
token_id
),
...
...
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp
View file @
e1b457ec
...
...
@@ -71,9 +71,7 @@ struct FusedMoeGemmPipeline_General
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSizeA
()
{
// matrix a or tokens smem
constexpr
index_t
smem_mat_a
=
BlockShape
::
Block_M0
*
BlockShape
::
Block_K0
*
sizeof
(
ADataType
);
return
smem_mat_a
;
return
Policy
::
template
GetSmemSize_A
<
Problem
>();
}
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
{
...
...
@@ -131,11 +129,8 @@ struct FusedMoeGemmPipeline_General
CK_TILE_LDS_ADDR
void
*
smem
,
index_t
hidden_size
,
index_t
/*intermediate_size*/
,
CWindow
&
c_window_
)
CWindow
&
/*
c_window_
*/
)
{
ignore
=
c_window_
;
ignore
=
hidden_size
;
ignore
=
w_window_
;
CK_TILE_LDS_ADDR
ADataType
*
smem_0
=
reinterpret_cast
<
CK_TILE_LDS_ADDR
ADataType
*>
(
smem
);
CK_TILE_LDS_ADDR
GDataType
*
smem_1
=
reinterpret_cast
<
CK_TILE_LDS_ADDR
GDataType
*>
(
smem_0
+
GetSmemSizeA
()
/
sizeof
(
ADataType
));
...
...
@@ -234,11 +229,11 @@ struct FusedMoeGemmPipeline_General
#if 0
PrintMem(y_pre, "Y_pre", 0);
#endif
if
(
blockIdx
.
x
==
0
&&
blockIdx
.
y
==
0
&&
blockIdx
.
z
==
0
)
{
block_sync_lds
();
store_tile
(
c_window_
,
y_pre
);
}
//
if(blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
//
{
//
block_sync_lds();
//
store_tile(c_window_, y_pre);
//
}
// save to lds
auto
bridge_lds_view
=
make_tensor_view
<
address_space_enum
::
lds
>
(
smem_0
,
Policy
::
template
MakeBridgeLdsBlockDesc
<
Problem
>());
...
...
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp
View file @
e1b457ec
...
...
@@ -312,12 +312,6 @@ struct FusedMoeGemmPipelineGeneralPolicy
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
,
2
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
// constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
// make_tuple(number<Block_M>{}, number<Block_K>{}),
// make_tuple(number<Block_K>{}, number<1>{}),
// number<8>{},
// number<1>{});
return
a_lds_block_desc
;
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment