Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
7f4d6f08
"vscode:/vscode.git/clone" did not exist on "03059eb0e7815806705e4b3d1937faa7db9d60fd"
Commit
7f4d6f08
authored
Dec 26, 2024
by
letaoqin
Browse files
add MakeLdsBlockDesc_O
parent
ce97a2af
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
31 additions
and
19 deletions
+31
-19
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp
...ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp
+15
-18
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp
...ed_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp
+16
-1
No files found.
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp
View file @
7f4d6f08
...
@@ -81,8 +81,7 @@ struct FusedMoeGemmPipeline_General
...
@@ -81,8 +81,7 @@ struct FusedMoeGemmPipeline_General
// shuffle C matrix
// shuffle C matrix
constexpr
index_t
smem_bridge
=
Policy
::
template
GetSmemSize_Bridge
<
Problem
>();
constexpr
index_t
smem_bridge
=
Policy
::
template
GetSmemSize_Bridge
<
Problem
>();
constexpr
index_t
smem_mat_o
=
constexpr
index_t
smem_mat_o
=
BlockShape
::
Block_N1
*
BlockShape
::
Block_K1
*
sizeof
(
float
);
BlockShape
::
Block_N1
*
BlockShape
::
Block_K1
*
sizeof
(
float
);
return
max
(
smem_mat_a
+
smem_mat_d
,
smem_bridge
,
smem_mat_o
);
return
max
(
smem_mat_a
+
smem_mat_d
,
smem_bridge
,
smem_mat_o
);
// return Policy::template GetSmemSize<Problem>();
// return Policy::template GetSmemSize<Problem>();
...
@@ -304,18 +303,15 @@ struct FusedMoeGemmPipeline_General
...
@@ -304,18 +303,15 @@ struct FusedMoeGemmPipeline_General
#endif
#endif
// add to LDS
// add to LDS
CK_TILE_LDS_ADDR
float
*
smem_o
=
reinterpret_cast
<
CK_TILE_LDS_ADDR
float
*>
(
smem
);
CK_TILE_LDS_ADDR
float
*
smem_o
=
reinterpret_cast
<
CK_TILE_LDS_ADDR
float
*>
(
smem
);
auto
o_lds_view
=
auto
o_lds_view
=
make_tensor_view
<
address_space_enum
::
lds
>
(
make_naive_tensor_view
<
address_space_enum
::
lds
,
memory_operation_enum
::
set
>
(
smem_o
,
Policy
::
template
MakeLdsBlockDesc_O
<
Problem
>());
smem_o
,
auto
o_alds_win
=
make_tile_window
(
make_tuple
(
number
<
128
>
{},
number
<
32
>
{}),
o_lds_view
,
make_tuple
(
32
,
1
),
make_tuple
(
number
<
BlockShape
::
Block_K1
>
{},
number
<
BlockShape
::
Block_N1
>
{}),
number
<
8
>
{},
{
0
,
0
});
number
<
1
>
{});
auto
o_olds_win
=
make_tile_window
(
auto
o_alds_win
=
o_lds_view
,
make_tile_window
(
o_lds_view
,
make_tuple
(
number
<
128
>
{},
number
<
32
>
{}),
{
0
,
0
});
make_tuple
(
number
<
BlockShape
::
Block_M1
>
{},
number
<
BlockShape
::
Block_N1
>
{}),
auto
o_olds_win
=
make_tile_window
(
o_lds_view
,
make_tuple
(
number
<
32
>
{},
number
<
32
>
{}),
{
0
,
0
},
{
0
,
0
},
Policy
::
template
MakeGlobalTileDistribution_O
<
Problem
>());
Policy
::
template
MakeGlobalTileDistribution_O
<
Problem
>());
...
@@ -338,7 +334,8 @@ struct FusedMoeGemmPipeline_General
...
@@ -338,7 +334,8 @@ struct FusedMoeGemmPipeline_General
auto
o
=
cast_tile
<
ODataType
>
(
o0
);
auto
o
=
cast_tile
<
ODataType
>
(
o0
);
update_tile
(
o_window_
,
o
);
update_tile
(
o_window_
,
o
);
// restore pos
// restore pos
move_tile_window
(
o_olds_win
,
{
-
32
*
(
BlockShape
::
Repeat_K1
-
1
),
0
});
move_tile_window
(
o_olds_win
,
{
-
BlockShape
::
Block_M1
*
(
BlockShape
::
Repeat_K1
-
1
),
0
});
}
}
}
}
};
};
...
...
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp
View file @
7f4d6f08
...
@@ -134,7 +134,7 @@ struct FusedMoeGemmPipelineGeneralPolicy
...
@@ -134,7 +134,7 @@ struct FusedMoeGemmPipelineGeneralPolicy
constexpr
index_t
M_wav
=
NumWarps
/
K_wav
;
constexpr
index_t
M_wav
=
NumWarps
/
K_wav
;
static_assert
(
MPerBlock
%
M_wav
==
0
,
"this tile size is too small please check"
);
static_assert
(
MPerBlock
%
M_wav
==
0
,
"this tile size is too small please check"
);
constexpr
index_t
M_rep
=
MPerBlock
/
M_wav
;
constexpr
index_t
M_rep
=
MPerBlock
/
M_wav
;
static_assert
(
M_rep
<=
2
);
return
make_static_tile_distribution
(
return
make_static_tile_distribution
(
tile_distribution_encoding
<
tile_distribution_encoding
<
sequence
<
1
>
,
sequence
<
1
>
,
...
@@ -152,6 +152,7 @@ struct FusedMoeGemmPipelineGeneralPolicy
...
@@ -152,6 +152,7 @@ struct FusedMoeGemmPipelineGeneralPolicy
static_assert
(
MPerBlock
%
(
M_lan
*
M_wav
)
==
0
,
static_assert
(
MPerBlock
%
(
M_lan
*
M_wav
)
==
0
,
"this tile size is too small please check"
);
"this tile size is too small please check"
);
constexpr
index_t
M_rep
=
MPerBlock
/
(
M_lan
*
M_wav
);
constexpr
index_t
M_rep
=
MPerBlock
/
(
M_lan
*
M_wav
);
static_assert
(
M_rep
<=
2
);
return
make_static_tile_distribution
(
return
make_static_tile_distribution
(
tile_distribution_encoding
<
tile_distribution_encoding
<
sequence
<
1
>
,
sequence
<
1
>
,
...
@@ -354,6 +355,20 @@ struct FusedMoeGemmPipelineGeneralPolicy
...
@@ -354,6 +355,20 @@ struct FusedMoeGemmPipelineGeneralPolicy
return
d_lds_block_desc
;
return
d_lds_block_desc
;
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeLdsBlockDesc_O
()
{
constexpr
index_t
Block_N1
=
Problem
::
BlockShape
::
Block_N1
;
constexpr
index_t
Block_K1
=
Problem
::
BlockShape
::
Block_N1
;
constexpr
auto
desc
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
Block_K1
>
{},
number
<
Block_N1
>
{}),
make_tuple
(
number
<
Block_N1
>
{},
number
<
1
>
{}),
number
<
4
>
{},
number
<
1
>
{});
return
desc
;
}
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeBridgeLdsBlockDesc
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeBridgeLdsBlockDesc
()
{
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment