Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
199f7f71
Commit
199f7f71
authored
Sep 01, 2024
by
carlushuang
Browse files
modify moe
parent
33ceea62
Changes
22
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
14 additions
and
7 deletions
+14
-7
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
.../ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
+5
-5
include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp
include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp
+9
-2
No files found.
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
View file @
199f7f71
...
@@ -282,7 +282,7 @@ struct BlockFmhaPipelineQRKSVSAsync
...
@@ -282,7 +282,7 @@ struct BlockFmhaPipelineQRKSVSAsync
store_tile
(
lse_dram_window_tmp
,
tile_elementwise_in
(
lse_element_func
,
lse
));
store_tile
(
lse_dram_window_tmp
,
tile_elementwise_in
(
lse_element_func
,
lse
));
}
}
buffer_load_fence
(
0
);
// rocm-6.1, if whole tile is masked out, need to fence(0)
buffer_load_fence
_raw
(
0
);
// rocm-6.1, if whole tile is masked out, need to fence(0)
// otherwise will have compute error(maybe compiler bug?)
// otherwise will have compute error(maybe compiler bug?)
// Note: here occ are all cleard, return it
// Note: here occ are all cleard, return it
...
@@ -334,7 +334,7 @@ struct BlockFmhaPipelineQRKSVSAsync
...
@@ -334,7 +334,7 @@ struct BlockFmhaPipelineQRKSVSAsync
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
buffer_load_fence
(
k_dram_window
.
get_num_access
(),
q
.
get_thread_buffer
());
buffer_load_fence
_raw
(
k_dram_window
.
get_num_access
(),
q
.
get_thread_buffer
());
(
void
)
q_element_func
;
// ??? rocm-6.x if use q element func will have scratch on hdim=64/32
(
void
)
q_element_func
;
// ??? rocm-6.x if use q element func will have scratch on hdim=64/32
// auto q_tile = q; // tile_elementwise_in(q_element_func, q);
// auto q_tile = q; // tile_elementwise_in(q_element_func, q);
...
@@ -359,7 +359,7 @@ struct BlockFmhaPipelineQRKSVSAsync
...
@@ -359,7 +359,7 @@ struct BlockFmhaPipelineQRKSVSAsync
if
constexpr
(
i_k0
<
k0_loops
-
1
)
if
constexpr
(
i_k0
<
k0_loops
-
1
)
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
async_load_fence
(
k_dram_window
.
get_num_access
());
async_load_fence
_raw
(
k_dram_window
.
get_num_access
());
__builtin_amdgcn_s_barrier
();
__builtin_amdgcn_s_barrier
();
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
gemm_0
(
s_acc
,
gemm_0
(
s_acc
,
...
@@ -381,7 +381,7 @@ struct BlockFmhaPipelineQRKSVSAsync
...
@@ -381,7 +381,7 @@ struct BlockFmhaPipelineQRKSVSAsync
if
constexpr
(
k0_loops
<=
2
)
if
constexpr
(
k0_loops
<=
2
)
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
async_load_fence
();
async_load_fence
_raw
();
__builtin_amdgcn_s_barrier
();
__builtin_amdgcn_s_barrier
();
const
auto
bias_tile
=
load_tile
(
bias_dram_window
);
// load bias tile
const
auto
bias_tile
=
load_tile
(
bias_dram_window
);
// load bias tile
...
...
include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp
View file @
199f7f71
...
@@ -31,8 +31,12 @@ struct WarpGemmImpl
...
@@ -31,8 +31,12 @@ struct WarpGemmImpl
using
BWarpTensor
=
static_distributed_tensor
<
BDataType
,
BWarpDstr
>
;
using
BWarpTensor
=
static_distributed_tensor
<
BDataType
,
BWarpDstr
>
;
using
CWarpTensor
=
static_distributed_tensor
<
CDataType
,
CWarpDstr
>
;
using
CWarpTensor
=
static_distributed_tensor
<
CDataType
,
CWarpDstr
>
;
CK_TILE_DEVICE
void
operator
()(
CWarpTensor
&
c
,
const
AWarpTensor
&
a
,
const
BWarpTensor
&
b
)
const
template
<
typename
CTensor
,
typename
ATensor
,
typename
BTensor
>
CK_TILE_DEVICE
void
operator
()(
CTensor
&
c
,
const
ATensor
&
a
,
const
BTensor
&
b
)
const
{
{
static_assert
(
detail
::
is_similiar_distributed_tensor_v
<
CTensor
,
CWarpTensor
>
&&
detail
::
is_similiar_distributed_tensor_v
<
ATensor
,
AWarpTensor
>
&&
detail
::
is_similiar_distributed_tensor_v
<
BTensor
,
BWarpTensor
>
);
using
AVec
=
ext_vector_t
<
ADataType
,
AWarpTensor
::
get_thread_buffer_size
()
>
;
using
AVec
=
ext_vector_t
<
ADataType
,
AWarpTensor
::
get_thread_buffer_size
()
>
;
using
BVec
=
ext_vector_t
<
BDataType
,
BWarpTensor
::
get_thread_buffer_size
()
>
;
using
BVec
=
ext_vector_t
<
BDataType
,
BWarpTensor
::
get_thread_buffer_size
()
>
;
using
CVec
=
ext_vector_t
<
CDataType
,
CWarpTensor
::
get_thread_buffer_size
()
>
;
using
CVec
=
ext_vector_t
<
CDataType
,
CWarpTensor
::
get_thread_buffer_size
()
>
;
...
@@ -49,8 +53,11 @@ struct WarpGemmImpl
...
@@ -49,8 +53,11 @@ struct WarpGemmImpl
c
.
get_thread_buffer
().
template
set_as
<
CVec
>(
I0
,
c_vec
);
c
.
get_thread_buffer
().
template
set_as
<
CVec
>(
I0
,
c_vec
);
}
}
CK_TILE_DEVICE
auto
operator
()(
const
AWarpTensor
&
a
,
const
BWarpTensor
&
b
)
const
template
<
typename
ATensor
,
typename
BTensor
>
CK_TILE_DEVICE
auto
operator
()(
const
ATensor
&
a
,
const
BTensor
&
b
)
const
{
{
static_assert
(
detail
::
is_similiar_distributed_tensor_v
<
ATensor
,
AWarpTensor
>
&&
detail
::
is_similiar_distributed_tensor_v
<
BTensor
,
BWarpTensor
>
);
CWarpTensor
c
;
CWarpTensor
c
;
using
AVec
=
ext_vector_t
<
ADataType
,
AWarpTensor
::
get_thread_buffer_size
()
>
;
using
AVec
=
ext_vector_t
<
ADataType
,
AWarpTensor
::
get_thread_buffer_size
()
>
;
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment