Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
08ab9cfa
Commit
08ab9cfa
authored
Feb 28, 2024
by
aska-0096
Browse files
fix a typo of name
parent
8a6e65a3
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
8 additions
and
8 deletions
+8
-8
example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute_wmma.inc
...gemm/run_batched_gemm_scale_softmax_gemm_permute_wmma.inc
+2
-2
example/32_batched_gemm_scale_softmax_gemm/run_grouped_query_attention_forward_wmma.inc
...softmax_gemm/run_grouped_query_attention_forward_wmma.inc
+2
-2
example/32_batched_gemm_scale_softmax_gemm/run_multi_query_attention_forward_wmma.inc
...e_softmax_gemm/run_multi_query_attention_forward_wmma.inc
+2
-2
example/32_batched_gemm_scale_softmax_gemm/run_self_attention_wmma.inc
...tched_gemm_scale_softmax_gemm/run_self_attention_wmma.inc
+2
-2
No files found.
example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute_wmma.inc
View file @
08ab9cfa
...
...
@@ -182,9 +182,9 @@ int run(int argc, char* argv[])
printf
(
"Verification: %s
\n
"
,
do_verification
?
"ON"
:
"OFF"
);
// TODO ANT: replace array with vector?
ck
::
static_for
<
0
,
std
::
tuple_size_v
<
DeviceMHAFactory
>
,
1
>
{}([
&
](
auto
i
)
->
void
{
const
auto
device_
conv_
mha_instance
=
std
::
get
<
i
>
(
DeviceMHAFactory
{});
const
auto
device_mha_instance
=
std
::
get
<
i
>
(
DeviceMHAFactory
{});
using
DeviceMHAInstance
=
ck
::
remove_cvref_t
<
decltype
(
device_
conv_
mha_instance
)
>
;
using
DeviceMHAInstance
=
ck
::
remove_cvref_t
<
decltype
(
device_mha_instance
)
>
;
auto
gemm
=
DeviceMHAInstance
{};
auto
invoker
=
gemm
.
MakeInvoker
();
auto
argument
=
gemm
.
MakeArgument
(
static_cast
<
ADataType
*>
(
a_device_buf
.
GetDeviceBuffer
()),
...
...
example/32_batched_gemm_scale_softmax_gemm/run_grouped_query_attention_forward_wmma.inc
View file @
08ab9cfa
...
...
@@ -185,9 +185,9 @@ int run(int argc, char* argv[])
printf
(
"Verification: %s
\n
"
,
do_verification
?
"ON"
:
"OFF"
);
// TODO ANT: replace array with vector?
ck
::
static_for
<
0
,
std
::
tuple_size_v
<
DeviceMHAFactory
>
,
1
>
{}([
&
](
auto
i
)
->
void
{
const
auto
device_
conv_
mha_instance
=
std
::
get
<
i
>
(
DeviceMHAFactory
{});
const
auto
device_mha_instance
=
std
::
get
<
i
>
(
DeviceMHAFactory
{});
using
DeviceMHAInstance
=
ck
::
remove_cvref_t
<
decltype
(
device_
conv_
mha_instance
)
>
;
using
DeviceMHAInstance
=
ck
::
remove_cvref_t
<
decltype
(
device_mha_instance
)
>
;
auto
gemm
=
DeviceMHAInstance
{};
auto
invoker
=
gemm
.
MakeInvoker
();
auto
argument
=
gemm
.
MakeArgument
(
static_cast
<
ADataType
*>
(
a_device_buf
.
GetDeviceBuffer
()),
...
...
example/32_batched_gemm_scale_softmax_gemm/run_multi_query_attention_forward_wmma.inc
View file @
08ab9cfa
...
...
@@ -185,9 +185,9 @@ int run(int argc, char* argv[])
printf
(
"Verification: %s
\n
"
,
do_verification
?
"ON"
:
"OFF"
);
// TODO ANT: replace array with vector?
ck
::
static_for
<
0
,
std
::
tuple_size_v
<
DeviceMHAFactory
>
,
1
>
{}([
&
](
auto
i
)
->
void
{
const
auto
device_
conv_
mha_instance
=
std
::
get
<
i
>
(
DeviceMHAFactory
{});
const
auto
device_mha_instance
=
std
::
get
<
i
>
(
DeviceMHAFactory
{});
using
DeviceMHAInstance
=
ck
::
remove_cvref_t
<
decltype
(
device_
conv_
mha_instance
)
>
;
using
DeviceMHAInstance
=
ck
::
remove_cvref_t
<
decltype
(
device_mha_instance
)
>
;
auto
gemm
=
DeviceMHAInstance
{};
auto
invoker
=
gemm
.
MakeInvoker
();
auto
argument
=
gemm
.
MakeArgument
(
static_cast
<
ADataType
*>
(
a_device_buf
.
GetDeviceBuffer
()),
...
...
example/32_batched_gemm_scale_softmax_gemm/run_self_attention_wmma.inc
View file @
08ab9cfa
...
...
@@ -215,9 +215,9 @@ int run(int argc, char* argv[])
printf
(
"Verification: %s
\n
"
,
do_verification
?
"ON"
:
"OFF"
);
// TODO ANT: replace array with vector?
ck
::
static_for
<
0
,
std
::
tuple_size_v
<
DeviceMHAFactory
>
,
1
>
{}([
&
](
auto
i
)
->
void
{
const
auto
device_
conv_
mha_instance
=
std
::
get
<
i
>
(
DeviceMHAFactory
{});
const
auto
device_mha_instance
=
std
::
get
<
i
>
(
DeviceMHAFactory
{});
using
DeviceMHAInstance
=
ck
::
remove_cvref_t
<
decltype
(
device_
conv_
mha_instance
)
>
;
using
DeviceMHAInstance
=
ck
::
remove_cvref_t
<
decltype
(
device_mha_instance
)
>
;
auto
gemm
=
DeviceMHAInstance
{};
auto
invoker
=
gemm
.
MakeSelfAttnInvoker
();
auto
argument
=
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment