Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
ff5115be
Commit
ff5115be
authored
Jan 03, 2025
by
Aleksander Dudek
Browse files
[CK_TILE] Add GetName for grouped gemm
parent
0c4cf86e
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
64 additions
and
6 deletions
+64
-6
example/ck_tile/16_batched_gemm/batched_gemm.cpp
example/ck_tile/16_batched_gemm/batched_gemm.cpp
+2
-2
example/ck_tile/17_grouped_gemm/grouped_gemm.cpp
example/ck_tile/17_grouped_gemm/grouped_gemm.cpp
+2
-2
include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp
include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp
+25
-1
include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp
include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp
+35
-1
No files found.
example/ck_tile/16_batched_gemm/batched_gemm.cpp
View file @
ff5115be
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024
-2025
, Advanced Micro Devices, Inc. All rights reserved.
#include <hip/hip_runtime.h>
#include <hip/hip_runtime.h>
...
@@ -91,7 +91,7 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
...
@@ -91,7 +91,7 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
if
(
s
.
log_level_
>
0
)
if
(
s
.
log_level_
>
0
)
{
{
std
::
cout
<<
"Launching kernel with args:"
std
::
cout
<<
"Launching kernel
: "
<<
Kernel
::
GetName
()
<<
"
with args:"
<<
" grid: {"
<<
grids
.
x
<<
", "
<<
grids
.
y
<<
", "
<<
grids
.
z
<<
"}"
<<
" grid: {"
<<
grids
.
x
<<
", "
<<
grids
.
y
<<
", "
<<
grids
.
z
<<
"}"
<<
", blocks: {"
<<
blocks
.
x
<<
", "
<<
blocks
.
y
<<
", "
<<
blocks
.
z
<<
"}"
<<
", blocks: {"
<<
blocks
.
x
<<
", "
<<
blocks
.
y
<<
", "
<<
blocks
.
z
<<
"}"
<<
std
::
endl
;
<<
std
::
endl
;
...
...
example/ck_tile/17_grouped_gemm/grouped_gemm.cpp
View file @
ff5115be
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024
-2025
, Advanced Micro Devices, Inc. All rights reserved.
#include <hip/hip_runtime.h>
#include <hip/hip_runtime.h>
...
@@ -128,7 +128,7 @@ float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
...
@@ -128,7 +128,7 @@ float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
if
(
s
.
log_level_
>
0
)
if
(
s
.
log_level_
>
0
)
{
{
std
::
cout
<<
"Launching kernel with args:"
std
::
cout
<<
"Launching kernel
: "
<<
GroupedGemmKernel
::
GetName
()
<<
"
with args:"
<<
" grid: {"
<<
grids
.
x
<<
", "
<<
grids
.
y
<<
", "
<<
grids
.
z
<<
"}"
<<
" grid: {"
<<
grids
.
x
<<
", "
<<
grids
.
y
<<
", "
<<
grids
.
z
<<
"}"
<<
", blocks: {"
<<
blocks
.
x
<<
", "
<<
blocks
.
y
<<
", "
<<
blocks
.
z
<<
"}"
<<
", blocks: {"
<<
blocks
.
x
<<
", "
<<
blocks
.
y
<<
", "
<<
blocks
.
z
<<
"}"
<<
std
::
endl
;
<<
std
::
endl
;
...
...
include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp
View file @
ff5115be
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024
-2025
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -57,6 +57,30 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
...
@@ -57,6 +57,30 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
using
BLayout
=
typename
Base
::
BLayout
;
using
BLayout
=
typename
Base
::
BLayout
;
using
CLayout
=
typename
Base
::
CLayout
;
using
CLayout
=
typename
Base
::
CLayout
;
/// @brief Build a descriptive name for this batched GEMM kernel instantiation.
///
/// The result has the form
///   "gemm_batched_<prec>_<M>x<N>x<K>_<VecA>x<VecB>x<VecC>_<PadM>x<PadN>x<PadK>"
/// where:
///   - <prec> is Base::t2s<ADataType>::name, followed by "_" and
///     Base::t2s<BDataType>::name only when ADataType and BDataType differ;
///   - <M>, <N>, <K> are GemmPipeline::kMPerBlock / kNPerBlock / kKPerBlock;
///   - <VecA>, <VecB>, <VecC> are GemmPipeline::VectorSizeA / B / C;
///   - <PadM>, <PadN>, <PadK> are GemmPipeline::kPadM / kPadN / kPadK,
///     each rendered through std::to_string.
///
/// @return The assembled kernel name.
CK_TILE_HOST static std::string GetName()
{
    // clang-format off
    using P_ = GemmPipeline;

    // Precision tag. No helper macros are used here: names such as _SS_/_TS_
    // (underscore + uppercase letter) are reserved for the implementation.
    std::string prec_str = Base::template t2s<ADataType>::name;
    // The type comparison is a compile-time constant, so only one branch is kept.
    if constexpr(!std::is_same_v<ADataType, BDataType>)
    {
        prec_str += "_";
        prec_str += Base::template t2s<BDataType>::name;
    }

    return std::string("gemm_batched_") + prec_str + "_" +
        std::to_string(P_::kMPerBlock) + "x" + std::to_string(P_::kNPerBlock) + "x" + std::to_string(P_::kKPerBlock) + "_" +
        std::to_string(P_::VectorSizeA) + "x" + std::to_string(P_::VectorSizeB) + "x" + std::to_string(P_::VectorSizeC) + "_" +
        std::to_string(P_::kPadM) + "x" + std::to_string(P_::kPadN) + "x" + std::to_string(P_::kPadK);
    // clang-format on
}
struct
BatchedGemmKernelArgs
:
GemmKernelArgs
struct
BatchedGemmKernelArgs
:
GemmKernelArgs
{
{
index_t
batch_stride_A
;
index_t
batch_stride_A
;
...
...
include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp
View file @
ff5115be
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024
-2025
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -44,6 +44,40 @@ struct GroupedGemmKernel
...
@@ -44,6 +44,40 @@ struct GroupedGemmKernel
using
BDataType
=
remove_cvref_t
<
typename
GemmPipeline
::
BDataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
GemmPipeline
::
BDataType
>
;
using
CDataType
=
remove_cvref_t
<
typename
EpiloguePipeline
::
ODataType
>
;
using
CDataType
=
remove_cvref_t
<
typename
EpiloguePipeline
::
ODataType
>
;
// clang-format off
// Type-to-string trait: t2s<T>::name is the short precision tag used when
// composing kernel names. The primary template is declared but not defined,
// so only the element types specialized below can be used.
template <typename DataType>
struct t2s;

template <>
struct t2s<float>
{
    static constexpr char const* name = "fp32";
};

template <>
struct t2s<fp16_t>
{
    static constexpr char const* name = "fp16";
};

template <>
struct t2s<bf16_t>
{
    static constexpr char const* name = "bf16";
};

template <>
struct t2s<fp8_t>
{
    static constexpr char const* name = "fp8";
};

template <>
struct t2s<bf8_t>
{
    static constexpr char const* name = "bf8";
};

template <>
struct t2s<int8_t>
{
    static constexpr char const* name = "int8";
};
// clang-format on
/// @brief Build a descriptive name for this grouped GEMM kernel instantiation.
///
/// The result has the form
///   "gemm_grouped_<prec>_<M>x<N>x<K>_<VecA>x<VecB>x<VecC>_<PadM>x<PadN>x<PadK>"
/// where:
///   - <prec> is t2s<ADataType>::name, followed by "_" and
///     t2s<BDataType>::name only when ADataType and BDataType differ;
///   - <M>, <N>, <K> are GemmPipeline::kMPerBlock / kNPerBlock / kKPerBlock;
///   - <VecA>, <VecB>, <VecC> are GemmPipeline::VectorSizeA / B / C;
///   - <PadM>, <PadN>, <PadK> are GemmPipeline::kPadM / kPadN / kPadK,
///     each rendered through std::to_string.
///
/// @return The assembled kernel name.
CK_TILE_HOST static std::string GetName()
{
    // clang-format off
    using P_ = GemmPipeline;

    // Precision tag. No helper macros are used here: names such as _SS_/_TS_
    // (underscore + uppercase letter) are reserved for the implementation.
    std::string prec_str = t2s<ADataType>::name;
    // The type comparison is a compile-time constant, so only one branch is kept.
    if constexpr(!std::is_same_v<ADataType, BDataType>)
    {
        prec_str += "_";
        prec_str += t2s<BDataType>::name;
    }

    return std::string("gemm_grouped_") + prec_str + "_" +
        std::to_string(P_::kMPerBlock) + "x" + std::to_string(P_::kNPerBlock) + "x" + std::to_string(P_::kKPerBlock) + "_" +
        std::to_string(P_::VectorSizeA) + "x" + std::to_string(P_::VectorSizeB) + "x" + std::to_string(P_::VectorSizeC) + "_" +
        std::to_string(P_::kPadM) + "x" + std::to_string(P_::kPadN) + "x" + std::to_string(P_::kPadK);
    // clang-format on
}
struct
GemmTransKernelArg
struct
GemmTransKernelArg
{
{
GroupedGemmHostArgs
group_karg
;
GroupedGemmHostArgs
group_karg
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment