Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
1ba8a08f
Commit
1ba8a08f
authored
Aug 23, 2024
by
carlushuang
Browse files
update tmp work
parent
bf214665
Changes
4
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
348 additions
and
267 deletions
+348
-267
example/ck_tile/05_moe/fused_moe/kernel/fused_moe_kernel.hpp
example/ck_tile/05_moe/fused_moe/kernel/fused_moe_kernel.hpp
+2
-3
example/ck_tile/05_moe/fused_moe/pipeline/fused_moe_pipeline.hpp
.../ck_tile/05_moe/fused_moe/pipeline/fused_moe_pipeline.hpp
+18
-16
example/ck_tile/05_moe/fused_moe/pipeline/fused_moe_pipeline_policy.hpp
...e/05_moe/fused_moe/pipeline/fused_moe_pipeline_policy.hpp
+313
-234
example/ck_tile/05_moe/fused_moe/pipeline/fused_moe_traits.hpp
...le/ck_tile/05_moe/fused_moe/pipeline/fused_moe_traits.hpp
+15
-14
No files found.
example/ck_tile/05_moe/fused_moe/kernel/fused_moe_kernel.hpp
View file @
1ba8a08f
...
@@ -252,7 +252,7 @@ struct FusedMoeKernel
...
@@ -252,7 +252,7 @@ struct FusedMoeKernel
CK_TILE_DEVICE
void
operator
()(
Kargs
kargs
)
const
CK_TILE_DEVICE
void
operator
()(
Kargs
kargs
)
const
{
{
// allocate LDS
// allocate LDS
__shared__
char
smem_ptr
[
GetSmemSize
()];
//
__shared__ char smem_ptr[GetSmemSize()];
ck_tile
::
index_t
num_sorted_tiles
=
__builtin_amdgcn_readfirstlane
(
ck_tile
::
index_t
num_sorted_tiles
=
__builtin_amdgcn_readfirstlane
(
*
reinterpret_cast
<
const
ck_tile
::
index_t
*>
(
kargs
.
num_sorted_tiles_ptr
));
*
reinterpret_cast
<
const
ck_tile
::
index_t
*>
(
kargs
.
num_sorted_tiles_ptr
));
ck_tile
::
index_t
tile_id
=
__builtin_amdgcn_readfirstlane
(
blockIdx
.
x
;);
ck_tile
::
index_t
tile_id
=
__builtin_amdgcn_readfirstlane
(
blockIdx
.
x
;);
...
@@ -436,8 +436,7 @@ struct FusedMoeKernel
...
@@ -436,8 +436,7 @@ struct FusedMoeKernel
u_gtile_window
,
u_gtile_window
,
d_gtile_window
,
d_gtile_window
,
o_gtile_window
,
o_gtile_window
,
scale
,
scale
);
smem_ptr
);
tile_id
+=
gridDim
.
x
;
tile_id
+=
gridDim
.
x
;
}
}
...
...
example/ck_tile/05_moe/fused_moe/pipeline/fused_moe_pipeline.hpp
View file @
1ba8a08f
...
@@ -117,7 +117,7 @@ struct BlockFmhaPipelineQRKSVSAsync
...
@@ -117,7 +117,7 @@ struct BlockFmhaPipelineQRKSVSAsync
// this is the thread-offset along row/col
// this is the thread-offset along row/col
CK_TILE_HOST_DEVICE
static
auto
GetAIndex
()
CK_TILE_HOST_DEVICE
static
auto
GetAIndex
()
{
{
constexpr
auto
a_dist
=
Policy
::
template
Make
A
GlobalTileDistribution
<
Problem
>();
constexpr
auto
a_dist
=
Policy
::
template
MakeGlobalTileDistribution
_A
<
Problem
>();
const
auto
a_coord
=
a_dist
.
calculate_index
();
const
auto
a_coord
=
a_dist
.
calculate_index
();
return
a_coord
;
return
a_coord
;
}
}
...
@@ -142,7 +142,8 @@ struct BlockFmhaPipelineQRKSVSAsync
...
@@ -142,7 +142,8 @@ struct BlockFmhaPipelineQRKSVSAsync
OGlobalTensorView
&
o_gtile_window_tmp
,
OGlobalTensorView
&
o_gtile_window_tmp
,
// const void * sorted_weight_ptr,
// const void * sorted_weight_ptr,
ScaleDataType
scale
,
ScaleDataType
scale
,
void
*
smem_ptr
,
CK_TILE_LDS_ADDR
void
*
smem_0
,
CK_TILE_LDS_ADDR
void
*
smem_1
,
index_t
dim_size
,
index_t
dim_size
,
index_t
hidden_size
)
index_t
hidden_size
)
{
{
...
@@ -153,25 +154,25 @@ struct BlockFmhaPipelineQRKSVSAsync
...
@@ -153,25 +154,25 @@ struct BlockFmhaPipelineQRKSVSAsync
make_tile_window
(
a_gtile_window_tmp
.
get_bottom_tensor_view
(),
make_tile_window
(
a_gtile_window_tmp
.
get_bottom_tensor_view
(),
a_gtile_window_tmp
.
get_window_lengths
(),
a_gtile_window_tmp
.
get_window_lengths
(),
a_gtile_window_tmp
.
get_window_origin
(),
a_gtile_window_tmp
.
get_window_origin
(),
Policy
::
template
Make
A
GlobalTileDistribution
<
Problem
>());
Policy
::
template
MakeGlobalTileDistribution
_A
<
Problem
>());
auto
g_gtile_window
=
auto
g_gtile_window
=
make_tile_window
(
g_gtile_window_tmp
.
get_bottom_tensor_view
(),
make_tile_window
(
g_gtile_window_tmp
.
get_bottom_tensor_view
(),
g_gtile_window_tmp
.
get_window_lengths
(),
g_gtile_window_tmp
.
get_window_lengths
(),
g_gtile_window_tmp
.
get_window_origin
(),
g_gtile_window_tmp
.
get_window_origin
(),
Policy
::
template
MakeG
G
lobalTileDistribution
<
Problem
>());
Policy
::
template
MakeGlobalTileDistribution
_G
<
Problem
>());
auto
u_gtile_window
=
auto
u_gtile_window
=
make_tile_window
(
u_gtile_window_tmp
.
get_bottom_tensor_view
(),
make_tile_window
(
u_gtile_window_tmp
.
get_bottom_tensor_view
(),
u_gtile_window_tmp
.
get_window_lengths
(),
u_gtile_window_tmp
.
get_window_lengths
(),
u_gtile_window_tmp
.
get_window_origin
(),
u_gtile_window_tmp
.
get_window_origin
(),
Policy
::
template
Make
U
GlobalTileDistribution
<
Problem
>());
Policy
::
template
MakeGlobalTileDistribution
_U
<
Problem
>());
auto
d_gtile_window
=
auto
d_gtile_window
=
make_tile_window
(
d_gtile_window_tmp
.
get_bottom_tensor_view
(),
make_tile_window
(
d_gtile_window_tmp
.
get_bottom_tensor_view
(),
d_gtile_window_tmp
.
get_window_lengths
(),
d_gtile_window_tmp
.
get_window_lengths
(),
d_gtile_window_tmp
.
get_window_origin
(),
d_gtile_window_tmp
.
get_window_origin
(),
Policy
::
template
Make
D
GlobalTileDistribution
<
Problem
>());
Policy
::
template
MakeGlobalTileDistribution
_D
<
Problem
>());
auto
o_gtile_window
=
auto
o_gtile_window
=
make_tile_window
(
o_gtile_window_tmp
.
get_bottom_tensor_view
(),
make_tile_window
(
o_gtile_window_tmp
.
get_bottom_tensor_view
(),
...
@@ -187,12 +188,13 @@ struct BlockFmhaPipelineQRKSVSAsync
...
@@ -187,12 +188,13 @@ struct BlockFmhaPipelineQRKSVSAsync
auto
a_smem_ptr
=
reinterpret_cast
<
ADataType
*>
(
smem_ptr
)
+
a_smem_offset
;
auto
a_smem_ptr
=
reinterpret_cast
<
ADataType
*>
(
smem_ptr
)
+
a_smem_offset
;
make_tile_window
(
make_tensor_view
<
address_space_enum
::
lds
>
(
auto
smem_0_window
=
make_tile_window
(
a_smem_ptr
,
Policy
::
template
MakeALdsStoreBlockDescriptor
<
Problem
>()),
make_tensor_view
<
address_space_enum
::
lds
>
(
Policy
::
template
MakeALdsStoreBlockDescriptor
<
Problem
>().
get_lengths
(),
smem_0
,
Policy
::
template
MakeLdsStoreBlockDescriptor_A
<
Problem
>()),
{
0
,
0
});
Policy
::
template
MakeLdsStoreBlockDescriptor_A
<
Problem
>().
get_lengths
(),
{
0
,
0
});
async_load_tile
_raw
(
k_lds_store
(
LdsSeq
.
at
(
number
<
0
>
{}))
,
a_gtile_window
);
async_load_tile
(
k_lds_store
(
LdsSeq
.
at
(
number
<
0
>
{})));
for
(
index_t
i_0
=
0
;
i_0
<
loops_0
;
i_0
++
)
{}
for
(
index_t
i_0
=
0
;
i_0
<
loops_0
;
i_0
++
)
{}
}
}
...
@@ -351,8 +353,8 @@ struct BlockFmhaPipelineQRKSVSAsync
...
@@ -351,8 +353,8 @@ struct BlockFmhaPipelineQRKSVSAsync
store_tile
(
lse_dram_window_tmp
,
tile_elementwise_in
(
lse_element_func
,
lse
));
store_tile
(
lse_dram_window_tmp
,
tile_elementwise_in
(
lse_element_func
,
lse
));
}
}
buffer_load_fence
(
0
);
// rocm-6.1, if whole tile is masked out, need to fence(0)
buffer_load_fence
_raw
(
0
);
// rocm-6.1, if whole tile is masked out, need to fence(0)
// otherwise will have compute error(maybe compiler bug?)
// otherwise will have compute error(maybe compiler bug?)
// Note: here occ are all cleard, return it
// Note: here occ are all cleard, return it
return
o_acc
;
return
o_acc
;
...
@@ -403,7 +405,7 @@ struct BlockFmhaPipelineQRKSVSAsync
...
@@ -403,7 +405,7 @@ struct BlockFmhaPipelineQRKSVSAsync
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
buffer_load_fence
(
k_dram_window
.
get_num_access
(),
q
.
get_thread_buffer
());
buffer_load_fence
_raw
(
k_dram_window
.
get_num_access
(),
q
.
get_thread_buffer
());
(
void
)
q_element_func
;
// ??? rocm-6.x if use q element func will have scratch on hdim=64/32
(
void
)
q_element_func
;
// ??? rocm-6.x if use q element func will have scratch on hdim=64/32
// auto q_tile = q; // tile_elementwise_in(q_element_func, q);
// auto q_tile = q; // tile_elementwise_in(q_element_func, q);
...
@@ -428,7 +430,7 @@ struct BlockFmhaPipelineQRKSVSAsync
...
@@ -428,7 +430,7 @@ struct BlockFmhaPipelineQRKSVSAsync
if
constexpr
(
i_k0
<
k0_loops
-
1
)
if
constexpr
(
i_k0
<
k0_loops
-
1
)
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
async_load_fence
(
k_dram_window
.
get_num_access
());
async_load_fence
_raw
(
k_dram_window
.
get_num_access
());
__builtin_amdgcn_s_barrier
();
__builtin_amdgcn_s_barrier
();
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
gemm_0
(
s_acc
,
gemm_0
(
s_acc
,
...
@@ -450,7 +452,7 @@ struct BlockFmhaPipelineQRKSVSAsync
...
@@ -450,7 +452,7 @@ struct BlockFmhaPipelineQRKSVSAsync
if
constexpr
(
k0_loops
<=
2
)
if
constexpr
(
k0_loops
<=
2
)
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
async_load_fence
();
async_load_fence
_raw
();
__builtin_amdgcn_s_barrier
();
__builtin_amdgcn_s_barrier
();
const
auto
bias_tile
=
load_tile
(
bias_dram_window
);
// load bias tile
const
auto
bias_tile
=
load_tile
(
bias_dram_window
);
// load bias tile
...
...
example/ck_tile/05_moe/fused_moe/pipeline/fused_moe_pipeline_policy.hpp
View file @
1ba8a08f
This diff is collapsed.
Click to expand it.
example/ck_tile/05_moe/fused_moe/pipeline/fused_moe_traits.hpp
View file @
1ba8a08f
...
@@ -12,21 +12,22 @@
...
@@ -12,21 +12,22 @@
namespace
ck_tile
{
namespace
ck_tile
{
template
<
bool
GateUpPreShuffled_
=
false
,
enum
class
FusedMoePermuteStyle
bool
DownPreShuffled_
=
false
,
{
index_t
NumPrefetchA_
=
2
,
// permute_b_n0_k0_n1_k1_n2_k2 = 0, // 0,1,4,2,5,3,6
index_t
NumPrefetchG_
=
2
,
// permute_b_n0_n1_k0_k1_n2_k2 = 1, // 0,1,2,4,5,3,6
index_t
NumPrefetchU_
=
2
,
permute_b_nr_kr_kw_nw_kv
=
2
,
// 0,1,3,4,2,5
index_t
NumPrefetchD_
=
2
,
permute_b_nr_kr_waveflatten
=
permute_b_nr_kr_kw_nw_kv
,
index_t
kBlockPerCu_
=
-
1
/* overwrite occupancy if not -1 */
>
no_permute
=
999
,
};
template
<
bool
DownPreShuffled_
=
false
,
FusedMoePermuteStyle
PermuteStyle_
=
FusedMoePermuteStyle
::
permute_b_nr_kr_kw_nw_kv
,
index_t
kBlockPerCu_
=
-
1
/* overwrite occupancy if not -1 */
>
struct
FusedMoeTraits
struct
FusedMoeTraits
{
{
static
constexpr
bool
GateUpPreShuffled
=
GateUpPreShuffled_
;
static
constexpr
bool
DownPreShuffled
=
DownPreShuffled_
;
static
constexpr
bool
DownPreShuffled
=
DownPreShuffled_
;
static
constexpr
FusedMoePermuteStyle
PermuteStyle
=
PermuteStyle_
;
static
constexpr
index_t
NumPrefetchA
=
NumPrefetchA_
;
static
constexpr
index_t
kBlockPerCu
=
kBlockPerCu_
;
static
constexpr
index_t
NumPrefetchG
=
NumPrefetchG_
;
static
constexpr
index_t
NumPrefetchU
=
NumPrefetchU_
;
static
constexpr
index_t
NumPrefetchD
=
NumPrefetchD_
;
static
constexpr
index_t
kBlockPerCu
=
kBlockPerCu_
;
};
};
}
// namespace ck_tile
}
// namespace ck_tile
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment