Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
f09dc1f3
Commit
f09dc1f3
authored
Nov 07, 2024
by
carlushuang
Browse files
compiler ok
parent
3bb718ad
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
285 additions
and
107 deletions
+285
-107
example/ck_tile/15_fused_moe/CMakeLists.txt
example/ck_tile/15_fused_moe/CMakeLists.txt
+2
-0
include/ck_tile/core.hpp
include/ck_tile/core.hpp
+1
-0
include/ck_tile/core/arch/amd_buffer_addressing.hpp
include/ck_tile/core/arch/amd_buffer_addressing.hpp
+5
-0
include/ck_tile/core/utility/static_counter.hpp
include/ck_tile/core/utility/static_counter.hpp
+116
-0
include/ck_tile/host/reference/reference_fused_moe.hpp
include/ck_tile/host/reference/reference_fused_moe.hpp
+2
-2
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm.hpp
.../ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm.hpp
+93
-103
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp
...sed_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp
+35
-1
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp
...e/ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp
+11
-0
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp
+14
-0
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp
...e/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp
+1
-1
include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp
include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp
+5
-0
No files found.
example/ck_tile/15_fused_moe/CMakeLists.txt
View file @
f09dc1f3
...
...
@@ -11,5 +11,7 @@ set(TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS)
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
list
(
APPEND TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal
)
list
(
APPEND TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS -DCK_TILE_BUFFER_LOAD_AGPR=1
)
# TODO: enable load to a
list
(
APPEND TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker
)
target_compile_options
(
${
TILE_EXAPMLE_FUSED_MOE
}
PRIVATE
${
TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS
}
)
include/ck_tile/core.hpp
View file @
f09dc1f3
...
...
@@ -62,6 +62,7 @@
#include "ck_tile/core/utility/philox_rand.hpp"
#include "ck_tile/core/utility/random.hpp"
#include "ck_tile/core/utility/reduce_operator.hpp"
#include "ck_tile/core/utility/static_counter.hpp"
#include "ck_tile/core/utility/to_sequence.hpp"
#include "ck_tile/core/utility/transpose_vectors.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
...
...
include/ck_tile/core/arch/amd_buffer_addressing.hpp
View file @
f09dc1f3
...
...
@@ -888,6 +888,11 @@ CK_TILE_DEVICE void buffer_store_fence(index_t cnt = 0)
asm
volatile
(
"s_waitcnt vmcnt(%0)"
:
:
"n"
(
cnt
)
:
"memory"
);
}
// Emits `s_waitcnt vmcnt(cnt)` as a volatile asm statement with a "memory" clobber,
// i.e. the same instruction sequence as buffer_store_fence() directly above.
// `cnt` is bound through the "n" (immediate integer) asm constraint.
// NOTE(review): presumably `cnt` must therefore be a compile-time constant at every
// call site after inlining - confirm against callers.
CK_TILE_DEVICE
auto
async_load_fence_raw
(
index_t
cnt
=
0
)
{
asm
volatile
(
"s_waitcnt vmcnt(%0)"
:
:
"n"
(
cnt
)
:
"memory"
);
}
// buffer load i8
CK_TILE_DEVICE_EXTERN
int8_t
llvm_amdgcn_raw_buffer_load_i8
(
int32x4_t
srsrc
,
...
...
include/ck_tile/core/utility/static_counter.hpp
0 → 100644
View file @
f09dc1f3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
namespace ck_tile {

// Compile-time counter built on friend injection ("stateful metaprogramming").
//
// Every call to next<...>() that is instantiated with a *new* Unique tag defines one more
// slot_allocated(slot<I>) friend and yields the index of that slot, so successive calls
// evaluate to Start, Start + Step, Start + 2 * Step, ...
// current<...>() yields the value produced by the most recent next() without allocating.
//
// Template parameters:
//   Context - tag type that separates independent counters (each Context has its own slots)
//   Start   - value returned by the first next()
//   Step    - increment between consecutive next() results
//
// NOTE(review): the result depends on template instantiation order; the technique is the
// subject of CWG issue 2118 and its behaviour may differ between compilers.
template <typename Context, index_t Start = 0, index_t Step = 1>
struct static_counter
{
    public:
    // Advance the counter, using the type `Unique` to force a fresh instantiation.
    template <typename Unique>
    static constexpr index_t next()
    {
        return next<Unique>(0) * Step + Start;
    }

    // Advance the counter, using an integer (e.g. __COUNTER__) to force a fresh instantiation.
    // The local struct is a distinct type for every distinct template argument.
    template <unsigned long long>
    static constexpr index_t next()
    {
        struct Unique
        {
        };
        return next<Unique>(0) * Step + Start;
    }

    // Value produced by the latest next(); fails to compile if next() was never invoked.
    template <typename Unique>
    static constexpr index_t current()
    {
        return current<Unique>(0) * Step + Start;
    }

    // Same as above, keyed by an integer instead of a type.
    template <unsigned long long>
    static constexpr index_t current()
    {
        struct Unique
        {
        };
        return current<Unique>(0) * Step + Start;
    }

    private:
    // Declares (but does not define) the probe function for slot I.
    template <index_t I>
    struct slot
    {
        _Pragma("GCC diagnostic push");
        _Pragma("GCC diagnostic ignored \"-Wundefined-internal\"");
        friend constexpr bool slot_allocated(slot<I>);
        _Pragma("GCC diagnostic pop");
    };

    // Instantiating this type defines slot_allocated(slot<I>), marking slot I as taken.
    template <index_t I>
    struct allocate_slot
    {
        friend constexpr bool slot_allocated(slot<I>) { return true; }

        enum
        {
            value = I
        };
    };

    // If slot_allocated(slot<I>) has NOT been defined, then SFINAE will keep this function out of
    // the overload set...
    template <typename Unique, index_t I = 0, bool = slot_allocated(slot<I>())>
    static constexpr index_t next(index_t)
    {
        return next<Unique, I + 1>(0);
    }

    // ...And this function will be used, instead, which will define slot_allocated(slot<I>) via
    // allocate_slot<I>.
    template <typename Unique, index_t I = 0>
    static constexpr index_t next(double)
    {
        return allocate_slot<I>::value;
    }

    // If slot_allocated(slot<I>) has NOT been defined, then SFINAE will keep this function out of
    // the overload set...
    // The probe starts at slot 0 (not at Start): slots are always allocated from index 0 by
    // next(), and Start/Step are applied only when the public wrappers scale the raw index.
    template <typename Unique, index_t I = 0, bool = slot_allocated(slot<I>())>
    static constexpr index_t current(index_t)
    {
        return current<Unique, I + 1>(0);
    }

    // ...And this function will be used, instead, which will return the current counter, or assert
    // in case next() hasn't been called yet.
    template <typename Unique, index_t I = 0>
    static constexpr index_t current(double)
    {
        static_assert(I != 0, "You must invoke next() first");
        return I - 1;
    }
};

namespace impl {
// Declaration-only tag; each distinct I names a distinct counter Context.
template <int I>
struct static_counter_uniq_;
} // namespace impl

#define MAKE_SC() \
    ck_tile::static_counter<ck_tile::impl::static_counter_uniq_<__COUNTER__>> {}
#define MAKE_SC_WITH(start_, step_) \
    ck_tile::static_counter<ck_tile::impl::static_counter_uniq_<__COUNTER__>, start_, step_> {}
#define NEXT_SC(c_) c_.next<__COUNTER__>()
#define NEXT_SCI(c_, static_i_) c_.next<__COUNTER__ + static_i_>()

// Usage:
// constexpr auto c = MAKE_SC()
// NEXT_SC(c)    // -> constexpr 0
// NEXT_SC(c)    // -> constexpr 1
// NEXT_SC(c)    // -> constexpr 2
} // namespace ck_tile
include/ck_tile/host/reference/reference_fused_moe.hpp
View file @
f09dc1f3
...
...
@@ -97,7 +97,7 @@ void reference_fused_moe(
int
max_num_tokens_padded
=
topk
*
tokens
+
experts
*
(
block_m
-
1
);
// assert();
auto
f
=
[
&
](
auto
i_flatten
)
{
ck_tile
::
index_t
i_tile
=
i_flatten
/
block_m
;
ck_tile
::
index_t
i_tile
=
i_flatten
/
block_m
;
if
(
i_tile
>=
num_sorted_tiles
)
return
;
ck_tile
::
index_t
i_expert
=
sorted_expert_ids_host
.
mData
[
i_tile
];
...
...
@@ -136,7 +136,7 @@ void reference_fused_moe(
{
AccDataType
tmp
;
Activation
{}(
tmp
,
acc_0
(
0
,
i_n
));
y
(
0
,
i_n
)
=
tmp
*
acc_0
(
0
,
i_n
+
hidden_size
);
// TODO: elementwise mul
y
(
0
,
i_n
)
=
tmp
*
acc_0
(
0
,
i_n
+
hidden_size
);
// TODO: elementwise mul
}
}
...
...
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm.hpp
View file @
f09dc1f3
...
...
@@ -156,10 +156,10 @@ struct FusedMoeGemmPipeline_Flatmm
using
g_thread_type
=
decltype
(
load_tile
(
g_win
));
using
d_thread_type
=
decltype
(
load_tile
(
d_win
));
//
using WarpGemm0 = Policy::template GetWarpGemm0<Problem>();
//
using WarpGemm1 = Policy::template GetWarpGemm1<Problem>();
//
auto warp_gemm_0 = WarpGemm0{};
//
auto warp_gemm_1 = WarpGemm1{};
using
WarpGemm0
=
decltype
(
Policy
::
template
GetWarpGemm0
<
Problem
>()
)
;
using
WarpGemm1
=
decltype
(
Policy
::
template
GetWarpGemm1
<
Problem
>()
)
;
auto
warp_gemm_0
=
WarpGemm0
{};
auto
warp_gemm_1
=
WarpGemm1
{};
// issues_warps_lanes
auto
a_sst_win0
=
...
...
@@ -175,7 +175,7 @@ struct FusedMoeGemmPipeline_Flatmm
{
0
,
0
,
0
});
// m*k
auto
a_sld_win0
=
[
&
]()
{
using
WG
=
decltype
(
Policy
::
template
GetWarpGemm0
<
Problem
>())
;
using
WG
=
WarpGemm0
;
constexpr
auto
a_outer_dstr_enc
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
BlockShape
::
Repeat_M0
,
BlockShape
::
WarpPerBlock_M0
>
,
...
...
@@ -196,7 +196,7 @@ struct FusedMoeGemmPipeline_Flatmm
// m*k
auto
a_sld_win1
=
[
&
]()
{
using
WG
=
decltype
(
Policy
::
template
GetWarpGemm0
<
Problem
>())
;
using
WG
=
WarpGemm0
;
constexpr
auto
a_outer_dstr_enc
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
BlockShape
::
Repeat_M0
,
BlockShape
::
WarpPerBlock_M0
>
,
...
...
@@ -242,10 +242,12 @@ struct FusedMoeGemmPipeline_Flatmm
constexpr
auto
issues_d
=
number
<
d_win
.
get_num_of_access
()
>
{};
constexpr
auto
issues_o
=
number
<
o_win
.
get_num_of_access
()
>
{};
constexpr
auto
issues_gemm0
=
number
<
BlockShape
::
Repeat_M0
*
BlockShape
::
Repeat_N0
*
BlockShape
::
Repeat_K0
>
{};
number
<
BlockShape
::
Repeat_M0
*
BlockShape
::
Repeat_N0
*
BlockShape
::
Repeat_K0
*
warp_gemm_0
.
get_num_of_access
()
>
{};
constexpr
auto
issues_gemm1
=
number
<
BlockShape
::
Repeat_M1
*
BlockShape
::
Repeat_N1
*
BlockShape
::
Repeat_K1
>
{};
constexpr
auto
issues_sld_a
=
number
<
a_sld_win0
.
get_num_of_access
()
>
{};
number
<
BlockShape
::
Repeat_M1
*
BlockShape
::
Repeat_N1
*
BlockShape
::
Repeat_K1
*
warp_gemm_1
.
get_num_of_access
()
>
{};
// constexpr auto issues_sld_a = number<a_sld_win0.get_num_of_access()>{};
const
index_t
num_blocks_k0
=
(
hidden_size
+
BlockShape
::
Block_K0
-
1
)
/
BlockShape
::
Block_K0
;
...
...
@@ -284,11 +286,9 @@ struct FusedMoeGemmPipeline_Flatmm
}
load_tile_raw
(
g_
,
g_win
,
i_access
,
FALSE
,
PreNop
{});
};
auto
move_g
=
[
&
]()
{
move_tile_window
(
g_win
,
{
number
<
0
>
{},
number
<
BlockShape
::
Block_Kr0
>
{},
number
<
0
>
{}});
};
auto
move_g
=
[
&
]()
{
move_tile_window
(
g_win
,
{
number
<
0
>
{},
number
<
BlockShape
::
Block_Kr0
>
{},
number
<
0
>
{}});
};
statically_indexed_array
<
d_thread_type
,
2
>
ds
;
auto
gld_d
=
[
&
]
<
typename
PreNop
=
bool_constant
<
false
>>
(
...
...
@@ -314,16 +314,17 @@ struct FusedMoeGemmPipeline_Flatmm
// clang-format off
auto
gemm_0
=
[
&
]
<
typename
PostNop
=
bool_constant
<
false
>>
(
auto
&
t_c
,
auto
&
t_a
,
auto
&
t_b
,
auto
i_access
,
PostNop
=
{})
{
auto
warp_gemm
=
Policy
::
template
GetWarpGemm0
<
Problem
>();
using
WarpGemm
=
remove_cvref_t
<
decltype
(
warp_gemm
)
>
;
using
WarpGemm
=
remove_cvref_t
<
decltype
(
warp_gemm_0
)
>
;
constexpr
auto
repeat_sub
=
WarpGemm
::
get_num_of_access
();
constexpr
auto
repeat_m
=
BlockShape
::
Repeat_M0
;
// constexpr auto repeat_n = BlockShape::Repeat_N0;
constexpr
auto
repeat_k
=
BlockShape
::
Repeat_K0
;
// loop order n->m->k
constexpr
auto
i_k
=
i_access
%
repeat_k
;
constexpr
auto
i_m
=
(
i_access
/
repeat_k
)
%
repeat_m
;
constexpr
auto
i_n
=
(
i_access
/
repeat_k
)
/
repeat_m
;
constexpr
auto
i_sub
=
i_access
%
repeat_sub
;
constexpr
auto
i_k
=
(
i_access
/
repeat_sub
)
%
repeat_k
;
constexpr
auto
i_m
=
(
i_access
/
(
repeat_sub
*
repeat_k
))
%
repeat_m
;
constexpr
auto
i_n
=
(
i_access
/
(
repeat_sub
*
repeat_k
))
/
repeat_m
;
using
AWarpTensor
=
typename
WarpGemm
::
AWarpTensor
;
using
BWarpTensor
=
typename
WarpGemm
::
BWarpTensor
;
...
...
@@ -355,7 +356,7 @@ struct FusedMoeGemmPipeline_Flatmm
merge_sequences
(
sequence
<
i_m
,
i_n
>
{},
c_warp_y_index_zeros
),
merge_sequences
(
sequence
<
1
,
1
>
{},
c_warp_y_lengths
));
W
arp
G
emm
{}
(
w_c
,
w_a
,
w_b
,
PostNop
{});
w
arp
_g
emm
_0
(
w_c
,
w_a
,
w_b
,
number
<
i_sub
>
{},
PostNop
{});
t_c
.
set_y_sliced_thread_data
(
merge_sequences
(
sequence
<
i_m
,
i_n
>
{},
c_warp_y_index_zeros
),
...
...
@@ -367,16 +368,17 @@ struct FusedMoeGemmPipeline_Flatmm
// clang-format off
auto
gemm_1
=
[
&
]
<
typename
PostNop
=
bool_constant
<
false
>>
(
auto
&
t_c
,
auto
&
t_a
,
auto
&
t_b
,
auto
i_access
,
PostNop
=
{})
{
auto
warp_gemm
=
Policy
::
template
GetWarpGemm1
<
Problem
>();
using
WarpGemm
=
remove_cvref_t
<
decltype
(
warp_gemm
)
>
;
using
WarpGemm
=
remove_cvref_t
<
decltype
(
warp_gemm_1
)
>
;
constexpr
auto
repeat_m
=
BlockShape
::
Repeat_M1
;
// constexpr auto repeat_n = BlockShape::Repeat_N1;
constexpr
auto
repeat_k
=
BlockShape
::
Repeat_K1
;
constexpr
auto
repeat_sub
=
WarpGemm
::
get_num_of_access
();
constexpr
auto
repeat_m
=
BlockShape
::
Repeat_M0
;
// constexpr auto repeat_n = BlockShape::Repeat_N0;
constexpr
auto
repeat_k
=
BlockShape
::
Repeat_K0
;
// loop order n->m->k
constexpr
auto
i_k
=
i_access
%
repeat_k
;
constexpr
auto
i_m
=
(
i_access
/
repeat_k
)
%
repeat_m
;
constexpr
auto
i_n
=
(
i_access
/
repeat_k
)
/
repeat_m
;
constexpr
auto
i_sub
=
i_access
%
repeat_sub
;
constexpr
auto
i_k
=
(
i_access
/
repeat_sub
)
%
repeat_k
;
constexpr
auto
i_m
=
(
i_access
/
(
repeat_sub
*
repeat_k
))
%
repeat_m
;
constexpr
auto
i_n
=
(
i_access
/
(
repeat_sub
*
repeat_k
))
/
repeat_m
;
using
AWarpTensor
=
typename
WarpGemm
::
AWarpTensor
;
using
BWarpTensor
=
typename
WarpGemm
::
BWarpTensor
;
...
...
@@ -408,7 +410,7 @@ struct FusedMoeGemmPipeline_Flatmm
merge_sequences
(
sequence
<
i_m
,
i_n
>
{},
c_warp_y_index_zeros
),
merge_sequences
(
sequence
<
1
,
1
>
{},
c_warp_y_lengths
));
W
arp
G
emm
{}
(
w_c
,
w_a
,
w_b
,
PostNop
{});
w
arp
_g
emm
_1
(
w_c
,
w_a
,
w_b
,
number
<
i_sub
>
{},
PostNop
{});
t_c
.
set_y_sliced_thread_data
(
merge_sequences
(
sequence
<
i_m
,
i_n
>
{},
c_warp_y_index_zeros
),
...
...
@@ -416,84 +418,72 @@ struct FusedMoeGemmPipeline_Flatmm
w_c
.
get_thread_buffer
());
};
// clang-format on
_Pragma
(
"clang diagnostic pop"
)
// this gemm pipeline is designed with assumption that issues of buffer-load/ds_read can
// be hide under mfma. In other words, issues of mfma is >= memory this is true if we
// pre-shuffle B matrix, and A matrix is relatively small we prefer use multiple mfma
// paired with 1 buffer-load B matrix, to get max throughput of buffer_load. and by
// preshuffle, we always pack to dwordx4 load, and this will already extend to multiple
// mfma but that is already consumed inside warpgemm-impl. So indeed how many extra
// mfma(that can reuse the B matrix) only affected by M repeat.
auto
pipeline_gemm0
=
[
&
]()
{
constexpr
index_t
total_loops
=
issues_gemm0
;
constexpr
index_t
mfma_per_ld
=
total_loops
/
(
issues_g
+
issues_a
+
issues_sld_a
);
// compute buffer 0
static_for
<
0
,
total_loops
,
1
>
{}([
&
](
auto
i_issue
)
{
gemm_0
(
acc_0
,
as
[
I0
],
gs
[
I0
],
i_issue
);
if
constexpr
(
i_issue
%
mfma_per_ld
==
0
)
{
constexpr
index_t
ld_id
=
0
;
if
constexpr
(
ld_id
<
issues_g
)
{
gld_g
(
gs
[
I0
],
number
<
ld_id
>
{});
}
if
constexpr
(
ld_id
-
issues_g
<
+
issues_a
)
{
gld_a
(
a_sst_win0
,
number
<
ld_id
-
issues_g
>
{});
}
if
constexpr
(
ld_id
-
issues_g
-
issues_a
<
issues_sld_a
)
{
sld_a
(
as
[
I1
],
a_sld_win1
,
number
<
ld_id
-
issues_g
-
issues_a
>
{});
}
ld_id
++
;
}
});
move_g
();
move_a
();
block_sync_load_raw
(
issues_a
+
issues_g
);
lds_load_fence
();
// compute buffer 1
static_for
<
0
,
total_loops
,
1
>
{}([
&
](
auto
i_issue
)
{
gemm_0
(
acc_0
,
as
[
I1
],
gs
[
I1
],
i_issue
);
if
constexpr
(
i_issue
%
mfma_per_ld
==
0
)
{
constexpr
index_t
ld_id
=
0
;
if
constexpr
(
ld_id
<
issues_g
)
{
gld_g
(
gs
[
I1
],
number
<
ld_id
>
{});
}
if
constexpr
(
ld_id
-
issues_g
<
+
issues_a
)
{
gld_a
(
a_sst_win1
,
number
<
ld_id
-
issues_g
>
{});
}
if
constexpr
(
ld_id
-
issues_g
-
issues_a
<
issues_sld_a
)
{
sld_a
(
as
[
I0
],
a_sld_win0
,
number
<
ld_id
-
issues_g
-
issues_a
>
{});
}
ld_id
++
;
}
});
move_g
();
move_a
();
block_sync_load_raw
(
issues_a
+
issues_g
);
lds_load_fence
();
};
_Pragma
(
"clang diagnostic pop"
);
// This gemm pipeline is designed on the assumption that the issues of buffer-load/ds_read
// can be hidden under mfma, i.e. the number of mfma issues is >= the number of memory issues.
// This holds if we pre-shuffle the B matrix and the A matrix is relatively small. We prefer to
// pair multiple mfma with one buffer-load of the B matrix to get maximum buffer_load throughput.
// With the pre-shuffle we always pack into dwordx4 loads, which already spans multiple mfma,
// but those are already consumed inside the warpgemm impl. So the number of extra mfma
// (those that can reuse the B matrix) is only affected by the M repeat.
auto
pipeline_gemm0
=
[
&
]()
{
constexpr
index_t
total_loops
=
issues_gemm0
;
constexpr
auto
sr
=
Policy
::
template
GetSequencer_0
<
Problem
>();
static_assert
(
sr
.
size
()
==
total_loops
);
constexpr
index_t
SLD_A
=
static_cast
<
index_t
>
(
FusedMoeGemmPipelineSequencerEnum
::
SLD_A
);
constexpr
index_t
GLD_A
=
static_cast
<
index_t
>
(
FusedMoeGemmPipelineSequencerEnum
::
GLD_A
);
constexpr
index_t
GLD_B
=
static_cast
<
index_t
>
(
FusedMoeGemmPipelineSequencerEnum
::
GLD_B
);
constexpr
auto
c_sld_a_0
=
MAKE_SC
();
constexpr
auto
c_gld_a_0
=
MAKE_SC
();
constexpr
auto
c_gld_b_0
=
MAKE_SC
();
// compute buffer 0
static_for
<
0
,
total_loops
,
1
>
{}([
&
](
auto
i_issue
)
{
gemm_0
(
acc_0
,
as
[
I0
],
gs
[
I0
],
i_issue
);
constexpr
index_t
slot
=
sr
.
at
(
i_issue
);
if
constexpr
(
slot
&
SLD_A
)
sld_a
(
as
[
I1
],
a_sld_win1
,
number
<
NEXT_SCI
(
c_sld_a_0
,
i_issue
)
>
{});
if
constexpr
(
slot
&
GLD_A
)
gld_a
(
a_sst_win0
,
number
<
NEXT_SCI
(
c_gld_a_0
,
i_issue
)
>
{});
if
constexpr
(
slot
&
GLD_B
)
gld_g
(
gs
[
I0
],
number
<
NEXT_SCI
(
c_gld_b_0
,
i_issue
)
>
{});
});
move_g
();
move_a
();
block_sync_load_raw
(
issues_a
+
issues_g
);
lds_load_fence
();
constexpr
auto
c_sld_a_1
=
MAKE_SC
();
constexpr
auto
c_gld_a_1
=
MAKE_SC
();
constexpr
auto
c_gld_b_1
=
MAKE_SC
();
// compute buffer 1
static_for
<
0
,
total_loops
,
1
>
{}([
&
](
auto
i_issue
)
{
gemm_0
(
acc_0
,
as
[
I1
],
gs
[
I1
],
i_issue
);
constexpr
index_t
slot
=
sr
.
at
(
i_issue
);
if
constexpr
(
slot
&
SLD_A
)
sld_a
(
as
[
I0
],
a_sld_win0
,
number
<
NEXT_SCI
(
c_sld_a_1
,
i_issue
)
>
{});
if
constexpr
(
slot
&
GLD_A
)
gld_a
(
a_sst_win1
,
number
<
NEXT_SCI
(
c_gld_a_1
,
i_issue
)
>
{});
if
constexpr
(
slot
&
GLD_B
)
gld_g
(
gs
[
I1
],
number
<
NEXT_SCI
(
c_gld_b_1
,
i_issue
)
>
{});
});
move_g
();
move_a
();
block_sync_load_raw
(
issues_a
+
issues_g
);
lds_load_fence
();
};
auto
pipeline_gemm0_tail
=
[
&
]()
{
constexpr
index_t
total_loops
=
issues_gemm0
;
constexpr
index_t
mfma_per_gld_g
=
total_loops
/
issues_g
;
// BlockShape::Repeat_M0;
// constexpr index_t mfma_per_gld_a = total_loops / issues_a;
constexpr
index_t
mfma_per_sld_a
=
total_loops
/
issues_sld_a
;
//
constexpr index_t mfma_per_sld_a = total_loops / issues_sld_a;
// compute buffer 0
static_for
<
0
,
total_loops
,
1
>
{}([
&
](
auto
i_issue
)
{
...
...
@@ -515,7 +505,7 @@ struct FusedMoeGemmPipeline_Flatmm
});
// if cycle_mfma>gld_a sync here
block_sync_load_raw
(
issues_g
);
sld_a
(
as
[
I1
],
a_sld_win1
,
NEG1
{}
);
sld_a
(
as
[
I1
],
a_sld_win1
,
NEG1
);
// compute buffer 1
static_for
<
0
,
total_loops
,
1
>
{}([
&
](
auto
i_issue
)
{
...
...
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp
View file @
f09dc1f3
...
...
@@ -609,11 +609,45 @@ struct FusedMoeGemmPipelineFlatmmPolicy
}
}
// Builds the issue schedule for gemm-0: a sequence<...> with one entry per mfma slot.
// Each entry is a bit mask of FusedMoeGemmPipelineSequencerEnum values naming the
// memory operation(s) to issue alongside that mfma (0 means nothing extra), so that
// those instructions can be hidden under the mfma stream.
// Only one configuration is handled; for any other Problem no value is returned.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSequencer_0()
{
    using Shape = typename Problem::BlockShape;
    using Op    = FusedMoeGemmPipelineSequencerEnum;

    constexpr index_t sld_a = static_cast<index_t>(Op::SLD_A);
    constexpr index_t gld_a = static_cast<index_t>(Op::GLD_A);
    constexpr index_t gld_b = static_cast<index_t>(Op::GLD_B);
    constexpr index_t idle  = 0; // slot with no memory operation attached

    constexpr bool bf16_in_out =
        std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
        std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t>;
    constexpr bool warp_32x32x16 =
        Shape::Warp_M0 == 32 && Shape::Warp_N0 == 32 && Shape::Warp_K0 == 16;
    constexpr bool block_32x512x128_n1_128 = Shape::Block_M0 == 32 && Shape::Block_N0 == 512 &&
                                             Shape::Block_K0 == 128 && Shape::Block_N1 == 128;

    if constexpr(bf16_in_out && warp_32x32x16 && block_32x512x128_n1_128)
    {
        // 64 slots in total: 32x gld_b (buffer-load-dwordx4), 8x gld_a
        // (buffer-load-dwordx1-async), 8x sld_a (ds_read_b128), 16x idle.
        // clang-format off
        //               0      1      2      3      4      5      6      7
        return sequence<gld_b, gld_a, gld_b, gld_a, gld_b, gld_a, gld_b, gld_a,  // 0
                        gld_b, gld_a, gld_b, gld_a, gld_b, gld_a, gld_b, gld_a,  // 1
                        gld_b, sld_a, gld_b, sld_a, gld_b, sld_a, gld_b, sld_a,  // 2
                        gld_b, sld_a, gld_b, sld_a, gld_b, sld_a, gld_b, sld_a,  // 3
                        gld_b, idle,  gld_b, idle,  gld_b, idle,  gld_b, idle,   // 4
                        gld_b, idle,  gld_b, idle,  gld_b, idle,  gld_b, idle,   // 5
                        gld_b, idle,  gld_b, idle,  gld_b, idle,  gld_b, idle,   // 6
                        gld_b, idle,  gld_b, idle,  gld_b, idle,  gld_b, idle>{}; // 7
        // clang-format on
    }
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetWarpGemm1
()
{
using
S_
=
typename
Problem
::
BlockShape
;
constexpr
auto
wg_ctrl
=
WGAttrCtlEnum
::
Raw_v
v
a
;
constexpr
auto
wg_ctrl
=
WGAttrCtlEnum
::
Raw_va
v
;
// TODO: ugly
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
YDataType
,
ck_tile
::
bf16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
DDataType
,
ck_tile
::
bf16_t
>
&&
...
...
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp
View file @
f09dc1f3
...
...
@@ -33,4 +33,15 @@ struct FusedMoeGemmTraits
static
constexpr
bool
PadHiddenSize
=
PadHiddenSize_
;
static
constexpr
bool
PadIntermediateSize
=
PadIntermediateSize_
;
};
// Operation tags used by the pipeline sequencer.
// Every enumerator owns exactly one bit, so values can be combined into / tested
// against a bit mask. Keep new entries as distinct single bits.
enum class FusedMoeGemmPipelineSequencerEnum
{
    SLD_A = 0x01, // shared load a
    SLD_B = 0x02, // presumably shared load b - TODO confirm
    GLD_A = 0x04, // global load a
    GLD_B = 0x08, // presumably global load b - TODO confirm
    SST_A = 0x10, // shared store a
    SST_B = 0x20, // presumably shared store b - TODO confirm
};
}
// namespace ck_tile
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp
View file @
f09dc1f3
...
...
@@ -25,6 +25,8 @@ struct WarpGemmAtrributeMfma
static
constexpr
index_t
kN
=
Impl
::
kN
;
static
constexpr
index_t
kK
=
Impl
::
kK
;
// Access count reported by this warp-gemm attribute; always 1 here.
// NOTE(review): presumably the number of separately issuable steps per warp-gemm
// call - confirm against the pipeline code that consumes it.
CK_TILE_HOST_DEVICE
static
constexpr
auto
get_num_of_access
()
{
return
1
;
}
using
AWarpDstrEncoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kAMLane
>
,
sequence
<
Impl
::
kABKLane
,
Impl
::
kABKPerLane
>>
,
...
...
@@ -88,6 +90,8 @@ struct WarpGemmAtrributeMfmaIterateK
static
constexpr
index_t
kN
=
Impl
::
kN
;
static
constexpr
index_t
kK
=
Impl
::
kK
*
kKIter
;
// Access count reported by this warp-gemm attribute: kKIter, the same factor
// that scales kK (= Impl::kK * kKIter) above.
// NOTE(review): presumably one access per K iteration - confirm.
CK_TILE_HOST_DEVICE
static
constexpr
auto
get_num_of_access
()
{
return
kKIter
;
}
using
AWarpDstrEncoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kAMLane
>
,
sequence
<
Impl
::
kABKLane
,
Impl
::
kABKPerLane
*
kKIter
>>
,
...
...
@@ -197,6 +201,8 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution
static
constexpr
index_t
kN
=
Impl
::
kM
;
static
constexpr
index_t
kK
=
Impl
::
kK
;
// Access count reported by this warp-gemm attribute; always 1 here.
// NOTE(review): presumably the number of separately issuable steps per warp-gemm
// call - confirm against the pipeline code that consumes it.
CK_TILE_HOST_DEVICE
static
constexpr
auto
get_num_of_access
()
{
return
1
;
}
using
AWarpDstrEncoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kBNLane
>
,
sequence
<
Impl
::
kABKLane
,
Impl
::
kABKPerLane
>>
,
...
...
@@ -258,6 +264,8 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB
static
constexpr
index_t
kN
=
Impl
::
kM
;
static
constexpr
index_t
kK
=
Impl
::
kK
;
// Access count reported by this warp-gemm attribute; always 1 here.
// NOTE(review): presumably the number of separately issuable steps per warp-gemm
// call - confirm against the pipeline code that consumes it.
CK_TILE_HOST_DEVICE
static
constexpr
auto
get_num_of_access
()
{
return
1
;
}
using
AWarpDstrEncoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kBNLane
>
,
sequence
<
Impl
::
kABKLane
,
Impl
::
kABKPerLane
>>
,
...
...
@@ -326,6 +334,8 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
static
constexpr
index_t
kN
=
Impl
::
kM
;
static
constexpr
index_t
kK
=
Impl
::
kK
*
kKIter
;
// Access count reported by this warp-gemm attribute: kKIter, the same factor
// that scales kK (= Impl::kK * kKIter) above.
// NOTE(review): presumably one access per K iteration - confirm.
CK_TILE_HOST_DEVICE
static
constexpr
auto
get_num_of_access
()
{
return
kKIter
;
}
using
AWarpDstrEncoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kBNLane
>
,
sequence
<
Impl
::
kABKLane
,
Impl
::
kABKPerLane
*
kKIter
>>
,
...
...
@@ -439,6 +449,8 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
static
constexpr
index_t
kK
=
Impl
::
kK
*
kKIter
;
static
constexpr
index_t
SFactor
=
SFactor_
;
// group how many CM1 together
// Access count reported by this warp-gemm attribute: kKIter, the same factor
// that scales kK (= Impl::kK * kKIter) above.
// NOTE(review): presumably one access per K iteration - confirm.
CK_TILE_HOST_DEVICE
static
constexpr
auto
get_num_of_access
()
{
return
kKIter
;
}
using
AWarpDstrEncoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kBNLane
>
,
sequence
<
Impl
::
kABKLane
,
Impl
::
kABKPerLane
*
kKIter
>>
,
...
...
@@ -576,6 +588,8 @@ struct WarpGemmAtrributeMfmaIterateK_SwizzleA
static
constexpr
index_t
kK
=
Impl
::
kK
*
kKIter
;
static
constexpr
index_t
SFactor
=
SFactor_
;
// group how many CM1 together
// Access count reported by this warp-gemm attribute: kKIter, the same factor
// that scales kK (= Impl::kK * kKIter) above.
// NOTE(review): presumably one access per K iteration - confirm.
CK_TILE_HOST_DEVICE
static
constexpr
auto
get_num_of_access
()
{
return
kKIter
;
}
using
AWarpDstrEncoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
Impl
::
kAMLane
/
(
Impl
::
kCMLane
*
SFactor
*
Impl
::
kCM1PerLane
),
...
...
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp
View file @
f09dc1f3
...
...
@@ -24,7 +24,7 @@ enum class WGAttrCtlEnum
#define DISPATCH_MFMA_(mfma_, dmod_, amod_, bmod_, cmod_) \
if constexpr(post_nop_) \
{ \
asm volatile(mfma_ " %0, %1, %2, %3
\n"
\
asm volatile(mfma_ " %0, %1, %2, %3
; yyy\n"
\
"s_nop 3" \
: dmod_(c_vec) \
: amod_(a_vec), bmod_(b_vec), cmod_(c_vec) \
...
...
include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp
View file @
f09dc1f3
...
...
@@ -31,6 +31,11 @@ struct WarpGemmImpl
using
BWarpTensor
=
static_distributed_tensor
<
BDataType
,
BWarpDstr
>
;
using
CWarpTensor
=
static_distributed_tensor
<
CDataType
,
CWarpDstr
>
;
// Forwards the access count of the underlying attribute type
// (WarpGemmAttribute_::get_num_of_access()) so callers can query it on the warp gemm itself.
CK_TILE_HOST_DEVICE
static
constexpr
auto
get_num_of_access
()
{
return
WarpGemmAttribute_
::
get_num_of_access
();
}
template
<
typename
CTensor
,
typename
ATensor
,
typename
BTensor
,
bool
post_nop_
=
false
>
CK_TILE_DEVICE
void
operator
()(
CTensor
&
c
,
const
ATensor
&
a
,
const
BTensor
&
b
,
bool_constant
<
post_nop_
>
=
{})
const
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment