Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
3bb718ad
Commit
3bb718ad
authored
Nov 06, 2024
by
valarLip
Browse files
update pipeline_gemm0
parent
c6c3c142
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
107 additions
and
47 deletions
+107
-47
include/ck_tile/core/arch/amd_buffer_addressing.hpp
include/ck_tile/core/arch/amd_buffer_addressing.hpp
+5
-0
include/ck_tile/core/arch/arch.hpp
include/ck_tile/core/arch/arch.hpp
+18
-0
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm.hpp
.../ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm.hpp
+84
-47
No files found.
include/ck_tile/core/arch/amd_buffer_addressing.hpp
View file @
3bb718ad
...
@@ -640,6 +640,11 @@ CK_TILE_DEVICE void buffer_load_fence(index_t cnt = 0)
...
@@ -640,6 +640,11 @@ CK_TILE_DEVICE void buffer_load_fence(index_t cnt = 0)
asm
volatile
(
"s_waitcnt vmcnt(%0)"
:
:
"n"
(
cnt
)
:
"memory"
);
asm
volatile
(
"s_waitcnt vmcnt(%0)"
:
:
"n"
(
cnt
)
:
"memory"
);
}
}
CK_TILE_DEVICE
void
lds_load_fence
(
index_t
cnt
=
0
)
{
asm
volatile
(
"s_waitcnt lgkmcnt(%0)"
:
:
"n"
(
cnt
)
:
"memory"
);
}
template
<
typename
scalar_type
,
index_t
N
,
bool
pre_nop
=
false
>
template
<
typename
scalar_type
,
index_t
N
,
bool
pre_nop
=
false
>
struct
buffer_atomic_add_if
;
struct
buffer_atomic_add_if
;
...
...
include/ck_tile/core/arch/arch.hpp
View file @
3bb718ad
...
@@ -73,6 +73,24 @@ CK_TILE_DEVICE void block_sync_lds()
...
@@ -73,6 +73,24 @@ CK_TILE_DEVICE void block_sync_lds()
#endif
#endif
}
}
CK_TILE_DEVICE
void
block_sync_load_raw
(
index_t
cnt
=
0
)
{
#ifdef __gfx12__
asm
volatile
(
"s_wait_loadcnt %0
\n
"
"s_barrier_signal -1
\n
"
"s_barrier_wait -1"
:
:
"n"
(
cnt
)
:
"memory"
);
#else
asm
volatile
(
"s_waitcnt vmcnt(%0)
\n
"
"s_barrier"
:
:
"n"
(
cnt
)
:
"memory"
);
#endif
}
CK_TILE_DEVICE
void
block_sync_lds_direct_load
()
CK_TILE_DEVICE
void
block_sync_lds_direct_load
()
{
{
asm
volatile
(
"\
asm
volatile
(
"\
...
...
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm.hpp
View file @
3bb718ad
...
@@ -260,9 +260,9 @@ struct FusedMoeGemmPipeline_Flatmm
...
@@ -260,9 +260,9 @@ struct FusedMoeGemmPipeline_Flatmm
{
{
async_load_tile_raw
(
a_store_
,
a_win
,
i_access
,
PreNop
{});
async_load_tile_raw
(
a_store_
,
a_win
,
i_access
,
PreNop
{});
};
};
//
auto move_a = [&]() {
auto
move_a
=
[
&
]()
{
//
move_tile_window(a_win, {number<0>{}, number<BlockShape::Block_K0>{}});
move_tile_window
(
a_win
,
{
number
<
0
>
{},
number
<
BlockShape
::
Block_K0
>
{}});
//
};
};
auto
sld_a
=
[
&
](
auto
&
a_
,
auto
&
win_
,
auto
i_access
)
{
auto
sld_a
=
[
&
](
auto
&
a_
,
auto
&
win_
,
auto
i_access
)
{
load_tile_raw
(
a_
,
win_
,
i_access
);
load_tile_raw
(
a_
,
win_
,
i_access
);
};
};
...
@@ -284,11 +284,11 @@ struct FusedMoeGemmPipeline_Flatmm
...
@@ -284,11 +284,11 @@ struct FusedMoeGemmPipeline_Flatmm
}
}
load_tile_raw
(
g_
,
g_win
,
i_access
,
FALSE
,
PreNop
{});
load_tile_raw
(
g_
,
g_win
,
i_access
,
FALSE
,
PreNop
{});
};
};
//
auto move_g =
auto
move_g
=
//
[&]() {
[
&
]()
{
//
move_tile_window(g_win,
move_tile_window
(
g_win
,
//
{number<0>{}, number<BlockShape::Block_Kr0>{}, number<0>{}});
{
number
<
0
>
{},
number
<
BlockShape
::
Block_Kr0
>
{},
number
<
0
>
{}});
//
};
};
statically_indexed_array
<
d_thread_type
,
2
>
ds
;
statically_indexed_array
<
d_thread_type
,
2
>
ds
;
auto
gld_d
=
[
&
]
<
typename
PreNop
=
bool_constant
<
false
>>
(
auto
gld_d
=
[
&
]
<
typename
PreNop
=
bool_constant
<
false
>>
(
...
@@ -296,10 +296,10 @@ struct FusedMoeGemmPipeline_Flatmm
...
@@ -296,10 +296,10 @@ struct FusedMoeGemmPipeline_Flatmm
{
{
load_tile_raw
(
d_
,
d_win
,
i_access
,
FALSE
,
PreNop
{});
load_tile_raw
(
d_
,
d_win
,
i_access
,
FALSE
,
PreNop
{});
};
};
//
auto move_d = [&]() {
auto
move_d
=
[
&
]()
{
//
// d move along gemm-n
// d move along gemm-n
//
move_tile_window(d_win, {number<BlockShape::Block_N1>{}, number<0>{}});
move_tile_window
(
d_win
,
{
number
<
BlockShape
::
Block_N1
>
{},
number
<
0
>
{}});
//
};
};
auto
atomic_add_o
=
[
&
]
<
typename
PreNop
=
bool_constant
<
false
>>
(
auto
atomic_add_o
=
[
&
]
<
typename
PreNop
=
bool_constant
<
false
>>
(
auto
&
o_
,
auto
i_access
,
PreNop
=
{})
auto
&
o_
,
auto
i_access
,
PreNop
=
{})
...
@@ -427,53 +427,66 @@ struct FusedMoeGemmPipeline_Flatmm
...
@@ -427,53 +427,66 @@ struct FusedMoeGemmPipeline_Flatmm
// mfma(that can reuse the B matrix) only affected by M repeat.
// mfma(that can reuse the B matrix) only affected by M repeat.
auto
pipeline_gemm0
=
[
&
]()
{
auto
pipeline_gemm0
=
[
&
]()
{
constexpr
index_t
total_loops
=
issues_gemm0
;
constexpr
index_t
total_loops
=
issues_gemm0
;
constexpr
index_t
mfma_per_gld_g
=
total_loops
/
issues_g
;
// BlockShape::Repeat_M0;
constexpr
index_t
mfma_per_ld
=
total_loops
/
(
issues_g
+
issues_a
+
issues_sld_a
);
constexpr
index_t
mfma_per_gld_a
=
total_loops
/
issues_a
;
constexpr
index_t
mfma_per_sld_a
=
total_loops
/
issues_sld_a
;
// compute buffer 0
// compute buffer 0
static_for
<
0
,
total_loops
,
1
>
{}([
&
](
auto
i_issue
)
{
static_for
<
0
,
total_loops
,
1
>
{}([
&
](
auto
i_issue
)
{
gemm_0
(
acc_0
,
as
[
I0
],
gs
[
I0
],
i_issue
);
gemm_0
(
acc_0
,
as
[
I0
],
gs
[
I0
],
i_issue
);
if
constexpr
(
i_issue
%
mfma_per_gld_g
==
0
)
{
gld_g
(
gs
[
I1
],
number
<
i_issue
/
mfma_per_gld_g
>
{});
move_g
();
}
if
constexpr
(
i_issue
%
mfma_per_
g
ld
_a
==
0
)
if
constexpr
(
i_issue
%
mfma_per_ld
==
0
)
{
{
gld_a
(
a_sst_win0
,
number
<
i_issue
/
mfma_per_gld_a
>
{});
constexpr
index_t
ld_id
=
0
;
move_a
();
if
constexpr
(
ld_id
<
issues_g
)
{
gld_g
(
gs
[
I0
],
number
<
ld_id
>
{});
}
if
constexpr
(
ld_id
-
issues_g
<
+
issues_a
)
{
gld_a
(
a_sst_win0
,
number
<
ld_id
-
issues_g
>
{});
}
if
constexpr
(
ld_id
-
issues_g
-
issues_a
<
issues_sld_a
)
{
sld_a
(
as
[
I1
],
a_sld_win1
,
number
<
ld_id
-
issues_g
-
issues_a
>
{});
}
ld_id
++
;
}
}
if
constexpr
(
i_issue
%
mfma_per_sld_a
==
0
)
{
block_sync_lds
();
sld_a
(
as
[
I1
],
a_sld_win1
,
number
<
i_issue
/
mfma_per_sld_a
>
{});
}
});
});
move_g
();
move_a
();
block_sync_load_raw
(
issues_a
+
issues_g
);
lds_load_fence
();
// compute buffer 1
// compute buffer 1
static_for
<
0
,
total_loops
,
1
>
{}([
&
](
auto
i_issue
)
{
static_for
<
0
,
total_loops
,
1
>
{}([
&
](
auto
i_issue
)
{
gemm_0
(
acc_0
,
as
[
I1
],
gs
[
I1
],
i_issue
);
gemm_0
(
acc_0
,
as
[
I1
],
gs
[
I1
],
i_issue
);
if
constexpr
(
i_issue
%
mfma_per_gld_g
==
0
)
{
gld_g
(
gs
[
I0
],
number
<
i_issue
/
mfma_per_gld_g
>
{});
move_g
();
}
if
constexpr
(
i_issue
%
mfma_per_gld_a
==
0
)
{
gld_a
(
a_sst_win1
,
number
<
i_issue
/
mfma_per_gld_a
>
{});
move_a
();
}
if
constexpr
(
i_issue
%
mfma_per_
s
ld
_a
==
0
)
if
constexpr
(
i_issue
%
mfma_per_ld
==
0
)
{
{
block_sync_lds
();
constexpr
index_t
ld_id
=
0
;
sld_a
(
as
[
I0
],
a_sld_win0
,
number
<
i_issue
/
mfma_per_sld_a
>
{});
if
constexpr
(
ld_id
<
issues_g
)
{
gld_g
(
gs
[
I1
],
number
<
ld_id
>
{});
}
if
constexpr
(
ld_id
-
issues_g
<
+
issues_a
)
{
gld_a
(
a_sst_win1
,
number
<
ld_id
-
issues_g
>
{});
}
if
constexpr
(
ld_id
-
issues_g
-
issues_a
<
issues_sld_a
)
{
sld_a
(
as
[
I0
],
a_sld_win0
,
number
<
ld_id
-
issues_g
-
issues_a
>
{});
}
ld_id
++
;
}
}
});
});
move_g
();
move_a
();
block_sync_load_raw
(
issues_a
+
issues_g
);
lds_load_fence
();
};
};
auto
pipeline_gemm0_tail
=
[
&
]()
{
auto
pipeline_gemm0_tail
=
[
&
]()
{
...
@@ -486,14 +499,23 @@ struct FusedMoeGemmPipeline_Flatmm
...
@@ -486,14 +499,23 @@ struct FusedMoeGemmPipeline_Flatmm
static_for
<
0
,
total_loops
,
1
>
{}([
&
](
auto
i_issue
)
{
static_for
<
0
,
total_loops
,
1
>
{}([
&
](
auto
i_issue
)
{
gemm_0
(
acc_0
,
as
[
I0
],
gs
[
I0
],
i_issue
);
gemm_0
(
acc_0
,
as
[
I0
],
gs
[
I0
],
i_issue
);
if
constexpr
(
i_issue
%
mfma_per_gld_g
==
0
)
if
constexpr
(
i_issue
%
mfma_per_gld_g
==
0
)
{
gld_g
(
gs
[
I1
],
number
<
i_issue
/
mfma_per_gld_g
>
{});
gld_g
(
gs
[
I1
],
number
<
i_issue
/
mfma_per_gld_g
>
{});
move_g
();
}
// if constexpr (i_issue % mfma_per_gld_a == 0)
// if constexpr (i_issue % mfma_per_gld_a == 0)
// gld_a(a_sst_win0, number<i_issue / mfma_per_gld_a>{});
// gld_a(a_sst_win0, number<i_issue / mfma_per_gld_a>{});
if
constexpr
(
i_issue
%
mfma_per_sld_a
==
0
)
// if constexpr(i_issue % mfma_per_sld_a == 0)
sld_a
(
as
[
I1
],
a_sld_win1
,
number
<
i_issue
/
mfma_per_sld_a
>
{});
// {
// block_sync_load_raw(a_sst_win0.get_num_of_access());
// sld_a(as[I1], a_sld_win1, number<i_issue / mfma_per_sld_a>{});
// }
});
});
// if cycle_mfma>gld_a sync here
block_sync_load_raw
(
issues_g
);
sld_a
(
as
[
I1
],
a_sld_win1
,
NEG1
{});
// compute buffer 1
// compute buffer 1
static_for
<
0
,
total_loops
,
1
>
{}([
&
](
auto
i_issue
)
{
static_for
<
0
,
total_loops
,
1
>
{}([
&
](
auto
i_issue
)
{
...
@@ -523,7 +545,10 @@ struct FusedMoeGemmPipeline_Flatmm
...
@@ -523,7 +545,10 @@ struct FusedMoeGemmPipeline_Flatmm
static_for
<
0
,
total_loops
,
1
>
{}([
&
](
auto
i_issue
)
{
static_for
<
0
,
total_loops
,
1
>
{}([
&
](
auto
i_issue
)
{
gemm_1
(
acc_1s
[
I1
],
y
,
ds
[
I1
],
i_issue
);
gemm_1
(
acc_1s
[
I1
],
y
,
ds
[
I1
],
i_issue
);
if
constexpr
(
i_issue
%
mfma_per_gld_d
==
0
)
if
constexpr
(
i_issue
%
mfma_per_gld_d
==
0
)
{
gld_d
(
ds
[
I0
],
number
<
i_issue
/
mfma_per_gld_d
>
{});
gld_d
(
ds
[
I0
],
number
<
i_issue
/
mfma_per_gld_d
>
{});
move_d
();
}
if
constexpr
(
i_issue
%
mfma_per_atm_o
==
0
)
if
constexpr
(
i_issue
%
mfma_per_atm_o
==
0
)
{
{
...
@@ -536,7 +561,10 @@ struct FusedMoeGemmPipeline_Flatmm
...
@@ -536,7 +561,10 @@ struct FusedMoeGemmPipeline_Flatmm
static_for
<
0
,
total_loops
,
1
>
{}([
&
](
auto
i_issue
)
{
static_for
<
0
,
total_loops
,
1
>
{}([
&
](
auto
i_issue
)
{
gemm_1
(
acc_1s
[
I0
],
y
,
ds
[
I0
],
i_issue
);
gemm_1
(
acc_1s
[
I0
],
y
,
ds
[
I0
],
i_issue
);
if
constexpr
(
i_issue
%
mfma_per_gld_d
==
0
)
if
constexpr
(
i_issue
%
mfma_per_gld_d
==
0
)
{
gld_d
(
ds
[
I1
],
number
<
i_issue
/
mfma_per_gld_d
>
{});
gld_d
(
ds
[
I1
],
number
<
i_issue
/
mfma_per_gld_d
>
{});
move_d
();
}
if
constexpr
(
i_issue
%
mfma_per_atm_o
==
0
)
if
constexpr
(
i_issue
%
mfma_per_atm_o
==
0
)
{
{
...
@@ -553,7 +581,10 @@ struct FusedMoeGemmPipeline_Flatmm
...
@@ -553,7 +581,10 @@ struct FusedMoeGemmPipeline_Flatmm
static_for
<
0
,
total_loops
,
1
>
{}([
&
](
auto
i_issue
)
{
static_for
<
0
,
total_loops
,
1
>
{}([
&
](
auto
i_issue
)
{
gemm_1
(
acc_1s
[
I0
],
y
,
ds
[
I0
],
i_issue
);
gemm_1
(
acc_1s
[
I0
],
y
,
ds
[
I0
],
i_issue
);
if
constexpr
(
i_issue
%
mfma_per_gld_d
==
0
)
if
constexpr
(
i_issue
%
mfma_per_gld_d
==
0
)
{
gld_d
(
ds
[
I1
],
number
<
i_issue
/
mfma_per_gld_d
>
{});
gld_d
(
ds
[
I1
],
number
<
i_issue
/
mfma_per_gld_d
>
{});
move_d
();
}
});
});
};
};
auto
pipeline_gemm1_tail
=
[
&
]()
{
auto
pipeline_gemm1_tail
=
[
&
]()
{
...
@@ -564,7 +595,10 @@ struct FusedMoeGemmPipeline_Flatmm
...
@@ -564,7 +595,10 @@ struct FusedMoeGemmPipeline_Flatmm
static_for
<
0
,
total_loops
,
1
>
{}([
&
](
auto
i_issue
)
{
static_for
<
0
,
total_loops
,
1
>
{}([
&
](
auto
i_issue
)
{
gemm_1
(
acc_1s
[
I1
],
y
,
ds
[
I1
],
i_issue
);
gemm_1
(
acc_1s
[
I1
],
y
,
ds
[
I1
],
i_issue
);
if
constexpr
(
i_issue
%
mfma_per_gld_d
==
0
)
if
constexpr
(
i_issue
%
mfma_per_gld_d
==
0
)
{
gld_d
(
ds
[
I0
],
number
<
i_issue
/
mfma_per_gld_d
>
{});
gld_d
(
ds
[
I0
],
number
<
i_issue
/
mfma_per_gld_d
>
{});
move_d
();
}
if
constexpr
(
i_issue
%
mfma_per_atm_o
==
0
)
if
constexpr
(
i_issue
%
mfma_per_atm_o
==
0
)
{
{
...
@@ -586,10 +620,13 @@ struct FusedMoeGemmPipeline_Flatmm
...
@@ -586,10 +620,13 @@ struct FusedMoeGemmPipeline_Flatmm
move_g
();
move_g
();
clear_tile
(
acc_0
);
clear_tile
(
acc_0
);
async_load_fence_raw
(
g_win
.
get_num_of_access
());
// preload for next round
s
ld_a
(
a
s
[
I0
],
a_sld
_win
0
,
NEG1
);
g
ld_a
(
a
_sst
_win
1
,
NEG1
);
gld_
a
(
a_sst_win1
,
NEG1
);
gld_
g
(
gs
[
I1
]
,
NEG1
);
// make sure a,g loaded
block_sync_load_raw
(
issues_a
+
issues_g
);
lds_load_fence
();
// we manually unroll double buffer inside hot loop
// we manually unroll double buffer inside hot loop
const
index_t
iters_0
=
(
num_blocks_k0
-
2
)
/
2
;
const
index_t
iters_0
=
(
num_blocks_k0
-
2
)
/
2
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment