Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
066bde75
Commit
066bde75
authored
Dec 05, 2024
by
Astha Rai
Browse files
Merge branch 'codegen_hiprtc' of github.com:ROCm/composable_kernel into codegen_hiprtc
parents
3c3d701e
76d96973
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
15 additions
and
17 deletions
+15
-17
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
+7
-8
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp
..._tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp
+1
-1
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp
+7
-8
No files found.
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
View file @
066bde75
...
@@ -998,14 +998,14 @@ struct FmhaFwdKernel
...
@@ -998,14 +998,14 @@ struct FmhaFwdKernel
return
pad_tensor_view
(
return
pad_tensor_view
(
q_dram_naive
,
q_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kSubQKHeaddim
>
{}),
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kSubQKHeaddim
>
{}),
sequence
<
kPadSeqLenQ
,
kPadHeadDimQ
>
{});
sequence
<
false
,
kPadHeadDimQ
>
{});
}
}
else
else
{
{
return
pad_tensor_view
(
return
pad_tensor_view
(
q_dram_naive
,
q_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kK0
>
{}),
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kK0
>
{}),
sequence
<
kPadSeqLenQ
,
kPadHeadDimQ
>
{});
sequence
<
false
,
kPadHeadDimQ
>
{});
}
}
}();
}();
const
auto
k_dram
=
[
&
]()
{
const
auto
k_dram
=
[
&
]()
{
...
@@ -1019,7 +1019,7 @@ struct FmhaFwdKernel
...
@@ -1019,7 +1019,7 @@ struct FmhaFwdKernel
return
pad_tensor_view
(
return
pad_tensor_view
(
k_dram_naive
,
k_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kN0
>
{},
number
<
FmhaPipeline
::
kK0
>
{}),
make_tuple
(
number
<
FmhaPipeline
::
kN0
>
{},
number
<
FmhaPipeline
::
kK0
>
{}),
sequence
<
kPadSeqLenK
,
kPadHeadDimQ
>
{});
sequence
<
false
,
kPadHeadDimQ
>
{});
}();
}();
const
auto
v_dram
=
[
&
]()
{
const
auto
v_dram
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
if
constexpr
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
...
@@ -1041,7 +1041,7 @@ struct FmhaFwdKernel
...
@@ -1041,7 +1041,7 @@ struct FmhaFwdKernel
return
pad_tensor_view
(
return
pad_tensor_view
(
v_dram_transposed
,
v_dram_transposed
,
make_tuple
(
number
<
FmhaPipeline
::
kN1
>
{},
number
<
FmhaPipeline
::
kK1
>
{}),
make_tuple
(
number
<
FmhaPipeline
::
kN1
>
{},
number
<
FmhaPipeline
::
kK1
>
{}),
sequence
<
kPadHeadDimV
,
kPadSeqLenK
>
{});
sequence
<
kPadHeadDimV
,
false
>
{});
}
}
else
else
{
{
...
@@ -1055,7 +1055,7 @@ struct FmhaFwdKernel
...
@@ -1055,7 +1055,7 @@ struct FmhaFwdKernel
return
pad_tensor_view
(
return
pad_tensor_view
(
v_dram_naive
,
v_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kN1
>
{},
number
<
FmhaPipeline
::
kK1
>
{}),
make_tuple
(
number
<
FmhaPipeline
::
kN1
>
{},
number
<
FmhaPipeline
::
kK1
>
{}),
sequence
<
kPadHeadDimV
,
kPadSeqLenK
>
{});
sequence
<
false
,
kPadSeqLenK
>
{});
}
}
}();
}();
...
@@ -1097,9 +1097,8 @@ struct FmhaFwdKernel
...
@@ -1097,9 +1097,8 @@ struct FmhaFwdKernel
number
<
FmhaPipeline
::
kAlignmentBias
>
{},
number
<
FmhaPipeline
::
kAlignmentBias
>
{},
number
<
1
>
{});
number
<
1
>
{});
return
pad_tensor_view
(
bias_dram_naive
,
return
pad_tensor_view
(
bias_dram_window_lengths
,
bias_dram_naive
,
bias_dram_window_lengths
,
sequence
<
false
,
kPadSeqLenK
>
{});
sequence
<
kPadSeqLenQ
,
kPadSeqLenK
>
{});
}();
}();
return
make_tile_window
(
bias_dram
,
bias_dram_window_lengths
,
{
i_m0
,
0
});
return
make_tile_window
(
bias_dram
,
bias_dram_window_lengths
,
{
i_m0
,
0
});
...
...
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp
View file @
066bde75
...
@@ -339,7 +339,7 @@ struct FmhaFwdSplitKVCombineKernel
...
@@ -339,7 +339,7 @@ struct FmhaFwdSplitKVCombineKernel
number
<
FmhaPipeline
::
kAlignmentOacc
>
{},
number
<
FmhaPipeline
::
kAlignmentOacc
>
{},
number
<
1
>
{});
number
<
1
>
{});
auto
o_acc_dram_view
=
pad_tensor_view
(
const
auto
o_acc_dram_view
=
pad_tensor_view
(
o_acc_dram_naive
,
o_acc_dram_naive
,
make_tuple
(
number
<
1
>
{},
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kN1
>
{}),
make_tuple
(
number
<
1
>
{},
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kN1
>
{}),
sequence
<
false
,
kPadSeqLenQ
,
kPadHeadDimV
>
{});
sequence
<
false
,
kPadSeqLenQ
,
kPadHeadDimV
>
{});
...
...
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp
View file @
066bde75
...
@@ -623,14 +623,14 @@ struct FmhaFwdSplitKVKernel
...
@@ -623,14 +623,14 @@ struct FmhaFwdSplitKVKernel
return
pad_tensor_view
(
return
pad_tensor_view
(
q_dram_naive
,
q_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kSubQKHeaddim
>
{}),
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kSubQKHeaddim
>
{}),
sequence
<
kPadSeqLenQ
,
kPadHeadDimQ
>
{});
sequence
<
false
,
kPadHeadDimQ
>
{});
}
}
else
else
{
{
return
pad_tensor_view
(
return
pad_tensor_view
(
q_dram_naive
,
q_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kK0
>
{}),
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kK0
>
{}),
sequence
<
kPadSeqLenQ
,
kPadHeadDimQ
>
{});
sequence
<
false
,
kPadHeadDimQ
>
{});
}
}
}();
}();
...
@@ -645,7 +645,7 @@ struct FmhaFwdSplitKVKernel
...
@@ -645,7 +645,7 @@ struct FmhaFwdSplitKVKernel
return
pad_tensor_view
(
return
pad_tensor_view
(
k_dram_naive
,
k_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kN0
>
{},
number
<
FmhaPipeline
::
kK0
>
{}),
make_tuple
(
number
<
FmhaPipeline
::
kN0
>
{},
number
<
FmhaPipeline
::
kK0
>
{}),
sequence
<
kPadSeqLenK
,
kPadHeadDimQ
>
{});
sequence
<
false
,
kPadHeadDimQ
>
{});
};
};
const
auto
k_dram
=
[
&
]()
{
const
auto
k_dram
=
[
&
]()
{
if
constexpr
(
kIsPagedKV
)
if
constexpr
(
kIsPagedKV
)
...
@@ -678,7 +678,7 @@ struct FmhaFwdSplitKVKernel
...
@@ -678,7 +678,7 @@ struct FmhaFwdSplitKVKernel
return
pad_tensor_view
(
return
pad_tensor_view
(
v_dram_transposed
,
v_dram_transposed
,
make_tuple
(
number
<
FmhaPipeline
::
kN1
>
{},
number
<
FmhaPipeline
::
kK1
>
{}),
make_tuple
(
number
<
FmhaPipeline
::
kN1
>
{},
number
<
FmhaPipeline
::
kK1
>
{}),
sequence
<
kPadHeadDimV
,
kPadSeqLenK
>
{});
sequence
<
kPadHeadDimV
,
false
>
{});
}
}
else
else
{
{
...
@@ -692,7 +692,7 @@ struct FmhaFwdSplitKVKernel
...
@@ -692,7 +692,7 @@ struct FmhaFwdSplitKVKernel
return
pad_tensor_view
(
return
pad_tensor_view
(
v_dram_naive
,
v_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kN1
>
{},
number
<
FmhaPipeline
::
kK1
>
{}),
make_tuple
(
number
<
FmhaPipeline
::
kN1
>
{},
number
<
FmhaPipeline
::
kK1
>
{}),
sequence
<
kPadHeadDimV
,
kPadSeqLenK
>
{});
sequence
<
false
,
kPadSeqLenK
>
{});
}
}
};
};
const
auto
v_dram
=
[
&
]()
{
const
auto
v_dram
=
[
&
]()
{
...
@@ -804,9 +804,8 @@ struct FmhaFwdSplitKVKernel
...
@@ -804,9 +804,8 @@ struct FmhaFwdSplitKVKernel
number
<
FmhaPipeline
::
kAlignmentBias
>
{},
number
<
FmhaPipeline
::
kAlignmentBias
>
{},
number
<
1
>
{});
number
<
1
>
{});
return
pad_tensor_view
(
bias_dram_naive
,
return
pad_tensor_view
(
bias_dram_window_lengths
,
bias_dram_naive
,
bias_dram_window_lengths
,
sequence
<
false
,
kPadSeqLenK
>
{});
sequence
<
kPadSeqLenQ
,
kPadSeqLenK
>
{});
}();
}();
return
make_tile_window
(
bias_dram
,
bias_dram_window_lengths
,
{
i_m0
,
0
});
return
make_tile_window
(
bias_dram
,
bias_dram_window_lengths
,
{
i_m0
,
0
});
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment