Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
ce97a2af
Commit
ce97a2af
authored
Dec 25, 2024
by
letaoqin
Browse files
rewrite getsmemsize
parent
e1b457ec
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
26 additions
and
19 deletions
+26
-19
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_general_kernel.hpp
...ile/ops/fused_moe/kernel/fused_moegemm_general_kernel.hpp
+1
-3
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp
...ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp
+16
-14
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp
...ed_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp
+9
-2
No files found.
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_general_kernel.hpp
View file @
ce97a2af
...
@@ -228,7 +228,7 @@ struct FusedMoeGemmGlKernel
...
@@ -228,7 +228,7 @@ struct FusedMoeGemmGlKernel
CK_TILE_DEVICE
void
operator
()(
Kargs
kargs
)
const
CK_TILE_DEVICE
void
operator
()(
Kargs
kargs
)
const
{
{
// allocate LDS
// allocate LDS
//
__shared__ char smem
_ptr
[GetSmemSize()];
__shared__
CK_TILE_LDS_ADDR
char
smem
[
GetSmemSize
()];
IndexDataType
num_sorted_tiles
=
__builtin_amdgcn_readfirstlane
(
IndexDataType
num_sorted_tiles
=
__builtin_amdgcn_readfirstlane
(
*
reinterpret_cast
<
const
IndexDataType
*>
(
kargs
.
num_sorted_tiles_ptr
));
*
reinterpret_cast
<
const
IndexDataType
*>
(
kargs
.
num_sorted_tiles_ptr
));
constexpr
index_t
hidden_radio_0
=
IsGateOnly
?
1
:
2
;
constexpr
index_t
hidden_radio_0
=
IsGateOnly
?
1
:
2
;
...
@@ -236,8 +236,6 @@ struct FusedMoeGemmGlKernel
...
@@ -236,8 +236,6 @@ struct FusedMoeGemmGlKernel
index_t
expert_stride_0
=
kargs
.
intermediate_size
*
hidden_radio_0
*
kargs
.
hidden_size
;
index_t
expert_stride_0
=
kargs
.
intermediate_size
*
hidden_radio_0
*
kargs
.
hidden_size
;
index_t
expert_stride_1
=
kargs
.
intermediate_size
*
kargs
.
hidden_size
;
index_t
expert_stride_1
=
kargs
.
intermediate_size
*
kargs
.
hidden_size
;
__shared__
CK_TILE_LDS_ADDR
ADataType
smem
[
GetSmemSize
()];
// note this is in unit of tile, need multiple tile size to get the index(block_m and
// note this is in unit of tile, need multiple tile size to get the index(block_m and
// block_n)
// block_n)
const
auto
[
sorted_tile_id
,
intermediate_tile_id
]
=
const
auto
[
sorted_tile_id
,
intermediate_tile_id
]
=
...
...
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp
View file @
ce97a2af
...
@@ -77,13 +77,14 @@ struct FusedMoeGemmPipeline_General
...
@@ -77,13 +77,14 @@ struct FusedMoeGemmPipeline_General
{
{
// matrix a or tokens smem
// matrix a or tokens smem
constexpr
index_t
smem_mat_a
=
GetSmemSizeA
();
constexpr
index_t
smem_mat_a
=
GetSmemSizeA
();
constexpr
index_t
smem_mat_d
=
constexpr
index_t
smem_mat_d
=
Policy
::
template
GetSmemSize_G
<
Problem
>();
BlockShape
::
Block_N0
*
BlockShape
::
Block_K0
*
sizeof
(
GDataType
);
// shuffle C matrix
// shuffle C matrix
constexpr
index_t
smem_bridge
=
constexpr
index_t
smem_bridge
=
Policy
::
template
GetSmemSize_Bridge
<
Problem
>();
BlockShape
::
Block_M0
*
BlockShape
::
Block_N0
*
sizeof
(
YDataType
);
return
max
(
smem_mat_a
+
smem_mat_d
,
smem_bridge
);
constexpr
index_t
smem_mat_o
=
BlockShape
::
Block_N1
*
BlockShape
::
Block_K1
*
sizeof
(
float
);
return
max
(
smem_mat_a
+
smem_mat_d
,
smem_bridge
,
smem_mat_o
);
// return Policy::template GetSmemSize<Problem>();
// return Policy::template GetSmemSize<Problem>();
}
}
...
@@ -131,19 +132,19 @@ struct FusedMoeGemmPipeline_General
...
@@ -131,19 +132,19 @@ struct FusedMoeGemmPipeline_General
index_t
/*intermediate_size*/
,
index_t
/*intermediate_size*/
,
CWindow
&
/*c_window_*/
)
CWindow
&
/*c_window_*/
)
{
{
CK_TILE_LDS_ADDR
ADataType
*
smem_
0
=
reinterpret_cast
<
CK_TILE_LDS_ADDR
ADataType
*>
(
smem
);
CK_TILE_LDS_ADDR
ADataType
*
smem_
a
=
reinterpret_cast
<
CK_TILE_LDS_ADDR
ADataType
*>
(
smem
);
CK_TILE_LDS_ADDR
GDataType
*
smem_
1
=
reinterpret_cast
<
CK_TILE_LDS_ADDR
GDataType
*>
(
CK_TILE_LDS_ADDR
GDataType
*
smem_
g
=
reinterpret_cast
<
CK_TILE_LDS_ADDR
GDataType
*>
(
smem_
0
+
GetSmemSizeA
()
/
sizeof
(
ADataType
));
smem_
a
+
GetSmemSizeA
()
/
sizeof
(
ADataType
));
auto
a_lds_view
=
make_tensor_view
<
address_space_enum
::
lds
>
(
auto
a_lds_view
=
make_tensor_view
<
address_space_enum
::
lds
>
(
smem_
0
,
Policy
::
template
MakeLdsBlockDesc_A
<
Problem
>());
smem_
a
,
Policy
::
template
MakeLdsBlockDesc_A
<
Problem
>());
auto
a_lds_win
=
make_tile_window
(
auto
a_lds_win
=
make_tile_window
(
a_lds_view
,
a_lds_view
,
make_tuple
(
number
<
BlockShape
::
Block_M0
>
{},
number
<
BlockShape
::
Block_K0
>
{}),
make_tuple
(
number
<
BlockShape
::
Block_M0
>
{},
number
<
BlockShape
::
Block_K0
>
{}),
{
0
,
0
});
{
0
,
0
});
auto
g_lds_view
=
make_tensor_view
<
address_space_enum
::
lds
>
(
auto
g_lds_view
=
make_tensor_view
<
address_space_enum
::
lds
>
(
smem_
1
,
Policy
::
template
MakeLdsBlockDesc_G
<
Problem
>());
smem_
g
,
Policy
::
template
MakeLdsBlockDesc_G
<
Problem
>());
auto
g_lds_win
=
make_tile_window
(
auto
g_lds_win
=
make_tile_window
(
g_lds_view
,
g_lds_view
,
make_tuple
(
number
<
BlockShape
::
Block_N0
>
{},
number
<
BlockShape
::
Block_K0
>
{}),
make_tuple
(
number
<
BlockShape
::
Block_N0
>
{},
number
<
BlockShape
::
Block_K0
>
{}),
...
@@ -235,8 +236,9 @@ struct FusedMoeGemmPipeline_General
...
@@ -235,8 +236,9 @@ struct FusedMoeGemmPipeline_General
// store_tile(c_window_, y_pre);
// store_tile(c_window_, y_pre);
// }
// }
// save to lds
// save to lds
CK_TILE_LDS_ADDR
ADataType
*
smem_y
=
reinterpret_cast
<
CK_TILE_LDS_ADDR
YDataType
*>
(
smem
);
auto
bridge_lds_view
=
make_tensor_view
<
address_space_enum
::
lds
>
(
auto
bridge_lds_view
=
make_tensor_view
<
address_space_enum
::
lds
>
(
smem_
0
,
Policy
::
template
MakeBridgeLdsBlockDesc
<
Problem
>());
smem_
y
,
Policy
::
template
MakeBridgeLdsBlockDesc
<
Problem
>());
auto
bridge_slds_win
=
auto
bridge_slds_win
=
make_tile_window
(
bridge_lds_view
,
make_tile_window
(
bridge_lds_view
,
Policy
::
template
MakeBridgeLdsBlockDesc
<
Problem
>().
get_lengths
(),
Policy
::
template
MakeBridgeLdsBlockDesc
<
Problem
>().
get_lengths
(),
...
@@ -285,7 +287,7 @@ struct FusedMoeGemmPipeline_General
...
@@ -285,7 +287,7 @@ struct FusedMoeGemmPipeline_General
{
{
for(int i = 0; i < 16; i++)
for(int i = 0; i < 16; i++)
{
{
printf("\n smem_
0
[%d]: %f ", i, type_convert<float>(smem_
0
[i]));
printf("\n smem_
a
[%d]: %f ", i, type_convert<float>(smem_
a
[i]));
}
}
}
}
//store_tile(c_window_, y);
//store_tile(c_window_, y);
...
@@ -301,10 +303,10 @@ struct FusedMoeGemmPipeline_General
...
@@ -301,10 +303,10 @@ struct FusedMoeGemmPipeline_General
PrintMem(d,"D",0);
PrintMem(d,"D",0);
#endif
#endif
// add to LDS
// add to LDS
CK_TILE_LDS_ADDR
float
*
smem_
3
=
reinterpret_cast
<
CK_TILE_LDS_ADDR
float
*>
(
smem
);
CK_TILE_LDS_ADDR
float
*
smem_
o
=
reinterpret_cast
<
CK_TILE_LDS_ADDR
float
*>
(
smem
);
auto
o_lds_view
=
auto
o_lds_view
=
make_naive_tensor_view
<
address_space_enum
::
lds
,
memory_operation_enum
::
set
>
(
make_naive_tensor_view
<
address_space_enum
::
lds
,
memory_operation_enum
::
set
>
(
smem_
3
,
smem_
o
,
make_tuple
(
number
<
128
>
{},
number
<
32
>
{}),
make_tuple
(
number
<
128
>
{},
number
<
32
>
{}),
make_tuple
(
32
,
1
),
make_tuple
(
32
,
1
),
number
<
8
>
{},
number
<
8
>
{},
...
...
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp
View file @
ce97a2af
...
@@ -94,14 +94,21 @@ struct FusedMoeGemmPipelineGeneralPolicy
...
@@ -94,14 +94,21 @@ struct FusedMoeGemmPipelineGeneralPolicy
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize_A
()
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize_A
()
{
{
constexpr
auto
a_lds_desc
=
MakeLdsBlockDesc_A
<
Problem
>
();
constexpr
auto
a_lds_desc
=
MakeLdsBlockDesc_A
<
Problem
>
();
return
a_lds_desc
.
get_element_space_size
();
return
a_lds_desc
.
get_element_space_size
()
*
sizeof
(
typename
Problem
::
ADataType
);
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize_G
()
{
constexpr
auto
g_lds_desc
=
MakeLdsBlockDesc_G
<
Problem
>
();
return
g_lds_desc
.
get_element_space_size
()
*
sizeof
(
typename
Problem
::
GDataType
);
}
}
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize_Bridge
()
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize_Bridge
()
{
{
constexpr
auto
bridge_lds_desc
=
MakeBridgeLdsBlockDesc
<
Problem
>
();
constexpr
auto
bridge_lds_desc
=
MakeBridgeLdsBlockDesc
<
Problem
>
();
return
bridge_lds_desc
.
get_element_space_size
();
return
bridge_lds_desc
.
get_element_space_size
()
*
sizeof
(
typename
Problem
::
YDataType
)
;
}
}
template
<
typename
Problem
>
template
<
typename
Problem
>
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment