Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
97efebdb
Commit
97efebdb
authored
Feb 02, 2025
by
Qianfeng Zhang
Browse files
Special treatment for hdim-96 to save vgprs in qr_ks_vs_async pipeline
parent
a94ac4bb
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
139 additions
and
95 deletions
+139
-95
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
.../ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
+139
-95
No files found.
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
View file @
97efebdb
...
...
@@ -76,25 +76,27 @@ struct BlockFmhaPipelineQRKSVSAsync
return
Problem
::
kBlockPerCu
;
else
{
if
constexpr
(
kQKHeaddim
<
=
32
)
if
constexpr
(
kQKHeaddim
=
=
32
)
{
return
2
;
}
else
if
constexpr
(
kQKHeaddim
<
=
64
)
else
if
constexpr
(
kQKHeaddim
=
=
64
)
{
return
2
;
}
else
if
constexpr
(
kQKHeaddim
<
=
128
)
else
if
constexpr
(
kQKHeaddim
==
96
||
kQKHeaddim
=
=
128
)
{
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
return
1
;
else
return
1
;
return
2
;
}
else
if
constexpr
(
kQKHeaddim
<
=
256
)
else
if
constexpr
(
kQKHeaddim
=
=
256
)
{
return
1
;
}
else
return
1
;
}
}();
...
...
@@ -170,7 +172,6 @@ struct BlockFmhaPipelineQRKSVSAsync
constexpr
auto
NumVLdsBuffers
=
Policy
::
template
GetNumVLdsBuffers
<
Problem
>();
static_assert
(
NumKLdsBuffers
>=
2
);
static_assert
(
NumVLdsBuffers
>=
2
);
auto
q_dram_window
=
make_tile_window
(
q_dram_block_window_tmp
.
get_bottom_tensor_view
(),
q_dram_block_window_tmp
.
get_window_lengths
(),
...
...
@@ -269,7 +270,13 @@ struct BlockFmhaPipelineQRKSVSAsync
using
k_tile_type
=
decltype
(
load_tile
(
k_dram_window
));
statically_indexed_array
<
k_tile_type
,
k0_loops
>
k_tiles
;
auto
k_tiles
=
[
&
]()
{
// for hdim-96 and hdim-160, try to save vgprs
if
constexpr
(
kQKHeaddim
<
kSubQKHeaddim
)
return
statically_indexed_array
<
k_tile_type
,
2
>
{};
else
return
statically_indexed_array
<
k_tile_type
,
k0_loops
>
{};
}();
k_tiles
[
I0
]
=
load_tile
(
k_dram_window
);
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
...
...
@@ -296,121 +303,158 @@ struct BlockFmhaPipelineQRKSVSAsync
do
{
if
(
i_total_loops
==
0
)
// executed by fist iteration
if
constexpr
(
kQKHeaddim
==
kSubQKHeaddim
)
{
if
(
num
_total_loop
>
1
)
// there are multiple
iteration
s
if
(
i
_total_loop
s
==
0
)
// executed by fist
iteration
{
auto
k_lds_window_tmp
=
get_slice_tile
(
k_lds_window
,
sequence
<
0
,
0
>
{},
sequence
<
kN0
,
kK0
>
{});
store_tile
(
k_lds_window_tmp
,
k_tiles
[
I0
]);
if
(
num_total_loop
>
1
)
// there are multiple iterations
{
auto
k_lds_window_tmp
=
get_slice_tile
(
k_lds_window
,
sequence
<
0
,
0
>
{},
sequence
<
kN0
,
kK0
>
{});
store_tile
(
k_lds_window_tmp
,
k_tiles
[
I0
]);
clear_tile
(
s_acc
);
// initialize C
clear_tile
(
s_acc
);
// initialize C
static_for
<
0
,
k0_loops
-
1
,
1
>
{}([
&
](
auto
i_k0
)
{
k_tiles
[
number
<
i_k0
+
1
>
{}]
=
load_tile
(
k_dram_window
);
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
static_for
<
0
,
k0_loops
-
1
,
1
>
{}([
&
](
auto
i_k0
)
{
k_tiles
[
number
<
i_k0
+
1
>
{}]
=
load_tile
(
k_dram_window
);
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
block_sync_lds
();
// execute current unroll of gemm_0
gemm_0
(
s_acc
,
q_tiles
[
number
<
i_k0
>
{}],
k_lds_window_tmp
);
block_sync_lds
();
// execute current unroll of gemm_0
gemm_0
(
s_acc
,
q_tiles
[
number
<
i_k0
>
{}],
k_lds_window_tmp
);
k_lds_window_tmp
=
get_slice_tile
(
k_lds_window
,
sequence
<
((
i_k0
+
1
)
%
NumKLdsBuffers
)
*
kN0
,
0
>
{},
sequence
<
(((
i_k0
+
1
)
%
NumKLdsBuffers
)
+
1
)
*
kN0
,
kK0
>
{});
store_tile
(
k_lds_window_tmp
,
k_tiles
[
number
<
i_k0
+
1
>
{}]);
});
k_lds_window_tmp
=
get_slice_tile
(
k_lds_window
,
sequence
<
((
i_k0
+
1
)
%
NumKLdsBuffers
)
*
kN0
,
0
>
{},
sequence
<
(((
i_k0
+
1
)
%
NumKLdsBuffers
)
+
1
)
*
kN0
,
kK0
>
{});
store_tile
(
k_lds_window_tmp
,
k_tiles
[
number
<
i_k0
+
1
>
{}]);
});
move_tile_window
(
k_dram_window
,
{
kN0
,
-
k0_loops
*
kK0
});
move_tile_window
(
k_dram_window
,
{
kN0
,
-
k0_loops
*
kK0
});
static_for
<
0
,
k0_loops
,
1
>
{}([
&
](
auto
i_k0
)
{
k_tiles
[
number
<
i_k0
>
{}]
=
load_tile
(
k_dram_window
);
static_for
<
0
,
k0_loops
,
1
>
{}([
&
](
auto
i_k0
)
{
k_tiles
[
number
<
i_k0
>
{}]
=
load_tile
(
k_dram_window
);
if
constexpr
(
i_k0
<
k0_loops
-
1
)
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
});
if
constexpr
(
i_k0
<
k0_loops
-
1
)
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
});
move_tile_window
(
k_dram_window
,
{
0
,
-
(
k0_loops
-
1
)
*
kK0
});
move_tile_window
(
k_dram_window
,
{
0
,
-
(
k0_loops
-
1
)
*
kK0
});
block_sync_lds
();
// execute last unroll of gemm_0
gemm_0
(
s_acc
,
q_tiles
[
number
<
k0_loops
-
1
>
{}],
k_lds_window_tmp
);
block_sync_lds
();
// execute last unroll of gemm_0
gemm_0
(
s_acc
,
q_tiles
[
number
<
k0_loops
-
1
>
{}],
k_lds_window_tmp
);
}
else
// there is only single iteration
{
auto
k_lds_window_tmp
=
get_slice_tile
(
k_lds_window
,
sequence
<
0
,
0
>
{},
sequence
<
kN0
,
kK0
>
{});
store_tile
(
k_lds_window_tmp
,
k_tiles
[
I0
]);
clear_tile
(
s_acc
);
// initialize C
static_for
<
0
,
k0_loops
,
1
>
{}([
&
](
auto
i_k0
)
{
if
constexpr
(
i_k0
<
k0_loops
-
1
)
{
k_tiles
[
number
<
i_k0
+
1
>
{}]
=
load_tile
(
k_dram_window
);
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
};
block_sync_lds
();
// execute current unroll of gemm_0
gemm_0
(
s_acc
,
q_tiles
[
number
<
i_k0
>
{}],
k_lds_window_tmp
);
if
constexpr
(
i_k0
<
k0_loops
-
1
)
{
k_lds_window_tmp
=
get_slice_tile
(
k_lds_window
,
sequence
<
((
i_k0
+
1
)
%
NumKLdsBuffers
)
*
kN0
,
0
>
{},
sequence
<
(((
i_k0
+
1
)
%
NumKLdsBuffers
)
+
1
)
*
kN0
,
kK0
>
{});
store_tile
(
k_lds_window_tmp
,
k_tiles
[
number
<
i_k0
+
1
>
{}]);
};
});
// move_tile_window(k_dram_window, {0, -k0_loops * kK0});
}
}
else
//
there is only single
iteration
else
//
executed by intermediate and last
iteration
{
auto
k_lds_window_tmp
=
get_slice_tile
(
k_lds_window
,
sequence
<
0
,
0
>
{},
sequence
<
kN0
,
kK0
>
{});
store_tile
(
k_lds_window_tmp
,
k_tiles
[
I0
]
);
if
(
i_total_loops
<
num_total_loop
-
1
)
// intermediate iteration
{
move_tile_window
(
k_dram_window
,
{
kN0
,
0
}
);
clear_tile
(
s_acc
);
// initialize C
static_for
<
0
,
k0_loops
,
1
>
{}([
&
](
auto
i_k0
)
{
auto
k_lds_window_tmp
=
get_slice_tile
(
k_lds_window
,
sequence
<
(
i_k0
%
NumKLdsBuffers
)
*
kN0
,
0
>
{},
sequence
<
((
i_k0
%
NumKLdsBuffers
)
+
1
)
*
kN0
,
kK0
>
{});
store_tile
(
k_lds_window_tmp
,
k_tiles
[
number
<
i_k0
>
{}]);
static_for
<
0
,
k0_loops
,
1
>
{}([
&
](
auto
i_k0
)
{
if
constexpr
(
i_k0
<
k0_loops
-
1
)
{
k_tiles
[
number
<
i_k0
+
1
>
{}]
=
load_tile
(
k_dram_window
);
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
};
k_tiles
[
number
<
i_k0
>
{}]
=
load_tile
(
k_dram_window
);
if
constexpr
(
i_k0
<
k0_loops
-
1
)
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
block_sync_lds
();
// execute current unroll of gemm_0
gemm_0
(
s_acc
,
q_tiles
[
number
<
i_k0
>
{}],
k_lds_window_tmp
);
if
constexpr
(
i_k0
==
0
)
clear_tile
(
s_acc
);
if
constexpr
(
i_k0
<
k0_loops
-
1
)
{
k_lds_window_tmp
=
get_slice_tile
(
block_sync_lds
();
gemm_0
(
s_acc
,
q_tiles
[
number
<
i_k0
>
{}],
k_lds_window_tmp
);
});
move_tile_window
(
k_dram_window
,
{
0
,
-
(
k0_loops
-
1
)
*
kK0
});
}
else
// last iteration
{
static_for
<
0
,
k0_loops
,
1
>
{}([
&
](
auto
i_k0
)
{
auto
k_lds_window_tmp
=
get_slice_tile
(
k_lds_window
,
sequence
<
((
i_k0
+
1
)
%
NumKLdsBuffers
)
*
kN0
,
0
>
{},
sequence
<
(((
i_k0
+
1
)
%
NumKLdsBuffers
)
+
1
)
*
kN0
,
kK0
>
{});
store_tile
(
k_lds_window_tmp
,
k_tiles
[
number
<
i_k0
+
1
>
{}]);
};
});
sequence
<
(
i_k0
%
NumKLdsBuffers
)
*
kN0
,
0
>
{},
sequence
<
((
i_k0
%
NumKLdsBuffers
)
+
1
)
*
kN0
,
kK0
>
{});
store_tile
(
k_lds_window_tmp
,
k_tiles
[
number
<
i_k0
>
{}]);
// move_tile_window(k_dram_window, {0, -k0_loops * kK0});
}
if
constexpr
(
i_k0
==
0
)
clear_tile
(
s_acc
);
block_sync_lds
();
gemm_0
(
s_acc
,
q_tiles
[
number
<
i_k0
>
{}],
k_lds_window_tmp
);
});
};
};
}
else
// executed by intermediate and last iteration
else
{
if
(
i_total_loops
<
num_total_loop
-
1
)
// intermediate iteration
{
move_tile_window
(
k_dram_window
,
{
kN0
,
0
}
);
auto
k_lds_window_tmp
=
get_slice_tile
(
k_lds_window
,
sequence
<
0
,
0
>
{},
sequence
<
kN0
,
kK0
>
{});
store_tile
(
k_lds_window_tmp
,
k_tiles
[
I0
]
);
static_for
<
0
,
k0_loops
,
1
>
{}([
&
](
auto
i_k0
)
{
auto
k_lds_window_tmp
=
get_slice_tile
(
k_lds_window
,
sequence
<
(
i_k0
%
NumKLdsBuffers
)
*
kN0
,
0
>
{},
sequence
<
((
i_k0
%
NumKLdsBuffers
)
+
1
)
*
kN0
,
kK0
>
{});
store_tile
(
k_lds_window_tmp
,
k_tiles
[
number
<
i_k0
>
{}]);
clear_tile
(
s_acc
);
// initialize C
k_tiles
[
number
<
i_k0
>
{}]
=
load_tile
(
k_dram_window
);
if
constexpr
(
i_k0
<
k0_loops
-
1
)
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
static_for
<
0
,
k0_loops
,
1
>
{}([
&
](
auto
i_k0
)
{
if
constexpr
(
i_k0
<
k0_loops
-
1
)
{
k_tiles
[
number
<
(
i_k0
+
1
)
%
2
>
{}]
=
load_tile
(
k_dram_window
);
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
};
if
constexpr
(
i_k0
==
0
)
clear_tile
(
s_acc
);
block_sync_lds
();
// execute current unroll of gemm_0
gemm_0
(
s_acc
,
q_tiles
[
number
<
i_k0
>
{}],
k_lds_window_tmp
);
block_sync_lds
();
// execute last unroll of gemm_0
gemm_0
(
s_acc
,
q_tiles
[
number
<
i_k0
>
{}],
k_lds_window_tmp
);
});
if
constexpr
(
i_k0
<
k0_loops
-
1
)
{
k_lds_window_tmp
=
get_slice_tile
(
k_lds_window
,
sequence
<
((
i_k0
+
1
)
%
NumKLdsBuffers
)
*
kN0
,
0
>
{},
sequence
<
(((
i_k0
+
1
)
%
NumKLdsBuffers
)
+
1
)
*
kN0
,
kK0
>
{});
store_tile
(
k_lds_window_tmp
,
k_tiles
[
number
<
(
i_k0
+
1
)
%
2
>
{}]);
};
});
move_tile_window
(
k_dram_window
,
{
0
,
-
(
k0_loops
-
1
)
*
kK0
});
}
else
// last iteration
if
(
i_total_loops
<
num_total_loop
-
1
)
{
static_for
<
0
,
k0_loops
,
1
>
{}([
&
](
auto
i_k0
)
{
auto
k_lds_window_tmp
=
get_slice_tile
(
k_lds_window
,
sequence
<
(
i_k0
%
NumKLdsBuffers
)
*
kN0
,
0
>
{},
sequence
<
((
i_k0
%
NumKLdsBuffers
)
+
1
)
*
kN0
,
kK0
>
{});
store_tile
(
k_lds_window_tmp
,
k_tiles
[
number
<
i_k0
>
{}]);
if
constexpr
(
i_k0
==
0
)
clear_tile
(
s_acc
);
block_sync_lds
();
// execute last unroll of gemm_0
gemm_0
(
s_acc
,
q_tiles
[
number
<
i_k0
>
{}],
k_lds_window_tmp
);
});
move_tile_window
(
k_dram_window
,
{
kN0
,
-
k0_loops
*
kK0
});
k_tiles
[
I0
]
=
load_tile
(
k_dram_window
);
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
};
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment