Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
abd2755a
Commit
abd2755a
authored
Jan 06, 2025
by
ThomasNing
Browse files
Merge branch 'develop' into moe_cross_reduce
parents
b74918bc
888317e6
Changes
166
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
41 additions
and
9 deletions
+41
-9
profiler/src/profile_gemm_universal.cpp
profiler/src/profile_gemm_universal.cpp
+14
-3
profiler/src/profile_gemm_universal_streamk.cpp
profiler/src/profile_gemm_universal_streamk.cpp
+19
-2
script/cmake-ck-dev.sh
script/cmake-ck-dev.sh
+2
-2
test/ck_tile/batched_gemm/test_batched_gemm_util.hpp
test/ck_tile/batched_gemm/test_batched_gemm_util.hpp
+2
-1
test/ck_tile/gemm/test_gemm_pipeline_util.hpp
test/ck_tile/gemm/test_gemm_pipeline_util.hpp
+3
-1
test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp
test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp
+1
-0
No files found.
profiler/src/profile_gemm_universal.cpp
View file @
abd2755a
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <initializer_list>
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "profiler/profile_gemm_universal_impl.hpp"
#include "profiler_operation_registry.hpp"
...
...
@@ -27,6 +27,8 @@ enum struct GemmDataType
F16_F8_F16
,
// 5
F16_F16_F16_F8
,
// 6
F8_F8_BF16
,
// 7
F16_I4_F16
,
// 8
BF16_I4_BF16
,
// 9
};
#define OP_NAME "gemm_universal"
...
...
@@ -39,7 +41,7 @@ int profile_gemm_universal(int argc, char* argv[])
printf
(
"arg1: tensor operation ("
OP_NAME
": "
OP_DESC
")
\n
"
);
printf
(
"arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8; 6: "
"f16->f8; 7: f8->bf16, "
"comp f8
)
\n
"
);
"comp f8
; 8: f16@i4; 9: bf16@i4
\n
"
);
printf
(
"arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];
\n
"
);
printf
(
" 1: A[m, k] * B[n, k] = C[m, n];
\n
"
);
printf
(
" 2: A[k, m] * B[k, n] = C[m, n];
\n
"
);
...
...
@@ -103,6 +105,7 @@ int profile_gemm_universal(int argc, char* argv[])
using
BF16
=
ck
::
bhalf_t
;
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)
using
F8
=
ck
::
f8_t
;
using
I4
=
ck
::
pk_i4_t
;
#endif
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
...
...
@@ -207,6 +210,14 @@ int profile_gemm_universal(int argc, char* argv[])
{
return
profile
(
F8
{},
F8
{},
F8
{},
F32
{},
BF16
{},
Row
{},
Col
{},
Row
{});
}
else
if
(
data_type
==
GemmDataType
::
F16_I4_F16
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
{
return
profile
(
F16
{},
I4
{},
F16
{},
F32
{},
F16
{},
Row
{},
Col
{},
Row
{});
}
else
if
(
data_type
==
GemmDataType
::
BF16_I4_BF16
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
{
return
profile
(
BF16
{},
I4
{},
BF16
{},
F32
{},
BF16
{},
Row
{},
Col
{},
Row
{});
}
#endif
else
{
...
...
profiler/src/profile_gemm_universal_streamk.cpp
100755 → 100644
View file @
abd2755a
...
...
@@ -83,8 +83,9 @@ int profile_gemm_universal_streamk(int argc, char* argv[])
rotating
=
std
::
stoull
(
argv
[
18
])
*
1024
*
1024
;
}
using
F32
=
float
;
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
F16
=
ck
::
half_t
;
using
BF16
=
ck
::
bhalf_t
;
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)
using
F8
=
ck
::
f8_t
;
...
...
@@ -165,6 +166,22 @@ int profile_gemm_universal_streamk(int argc, char* argv[])
return
profile
(
F8
{},
F16
{},
F32
{},
F16
{},
Row
{},
Col
{},
Row
{});
}
#endif
else
if
(
data_type
==
GemmDataType
::
BF16_BF16_BF16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
return
profile
(
BF16
{},
BF16
{},
F32
{},
BF16
{},
Row
{},
Row
{},
Row
{});
}
else
if
(
data_type
==
GemmDataType
::
BF16_BF16_BF16
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
{
return
profile
(
BF16
{},
BF16
{},
F32
{},
BF16
{},
Row
{},
Col
{},
Row
{});
}
else
if
(
data_type
==
GemmDataType
::
BF16_BF16_BF16
&&
layout
==
GemmMatrixLayout
::
KM_KN_MN
)
{
return
profile
(
BF16
{},
BF16
{},
F32
{},
BF16
{},
Col
{},
Row
{},
Row
{});
}
else
if
(
data_type
==
GemmDataType
::
BF16_BF16_BF16
&&
layout
==
GemmMatrixLayout
::
KM_NK_MN
)
{
return
profile
(
BF16
{},
BF16
{},
F32
{},
BF16
{},
Col
{},
Col
{},
Row
{});
}
else
{
std
::
cout
<<
"this data_type & layout is not implemented"
<<
std
::
endl
;
...
...
script/cmake-ck-dev.sh
View file @
abd2755a
...
...
@@ -15,9 +15,9 @@ else
fi
cmake
\
-D
CMAKE_PREFIX_PATH
=
/opt/rocm
\
-D
CMAKE_PREFIX_PATH
=
/opt/rocm
/
\
-D
CMAKE_CXX_COMPILER
=
/opt/rocm/bin/hipcc
\
-D
CMAKE_CXX_FLAGS
=
"-Xclang -mllvm -Xclang -enable-post-misched=0 -std=c++17 -O3 -ftemplate-backtrace-limit=0
-fPIE
-Wno-gnu-line-marker"
\
-D
CMAKE_CXX_FLAGS
=
"-Xclang -mllvm -Xclang -enable-post-misched=0 -std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker"
\
-D
CMAKE_BUILD_TYPE
=
Release
\
-D
BUILD_DEV
=
ON
\
-D
GPU_TARGETS
=
$GPU_TARGETS
\
...
...
test/ck_tile/batched_gemm/test_batched_gemm_util.hpp
View file @
abd2755a
...
...
@@ -93,7 +93,7 @@ class TestCkTileBatchedGemm : public ::testing::Test
auto
kargs
=
Kernel
::
MakeKernelArgs
(
args
);
const
dim3
grids
=
Kernel
::
GridSize
(
args
.
M
,
args
.
N
,
args
.
batch_count
);
const
dim3
grids
=
Kernel
::
GridSize
(
args
.
M
,
args
.
N
,
args
.
k_batch
,
args
.
batch_count
);
constexpr
dim3
blocks
=
Kernel
::
BlockSize
();
if
(
s
.
log_level_
>
0
)
...
...
@@ -186,6 +186,7 @@ class TestCkTileBatchedGemm : public ::testing::Test
args
.
a_ptr
=
a_m_k_dev_buf
.
GetDeviceBuffer
();
args
.
b_ptr
=
b_k_n_dev_buf
.
GetDeviceBuffer
();
args
.
c_ptr
=
c_m_n_dev_buf
.
GetDeviceBuffer
();
args
.
k_batch
=
1
;
args
.
M
=
M
;
args
.
N
=
N
;
args
.
K
=
K
;
...
...
test/ck_tile/gemm/test_gemm_pipeline_util.hpp
View file @
abd2755a
...
...
@@ -74,7 +74,9 @@ class TestCkTileGemmPipeline : public ::testing::Test
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
GemmShape
,
Traits
>>>
;
const
ck_tile
::
index_t
num_loop
=
TilePartitioner
::
GetLoopNum
(
args
.
K
);
const
ck_tile
::
index_t
k_grain
=
args
.
k_batch
*
K_Tile
;
const
ck_tile
::
index_t
K_split
=
(
args
.
K
+
k_grain
-
1
)
/
k_grain
*
K_Tile
;
const
ck_tile
::
index_t
num_loop
=
TilePartitioner
::
GetLoopNum
(
K_split
);
const
bool
has_hot_loop
=
BaseGemmPipeline
::
BlockHasHotloop
(
num_loop
);
const
ck_tile
::
TailNumber
tail_num
=
BaseGemmPipeline
::
GetBlockLoopTailNum
(
num_loop
);
...
...
test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp
View file @
abd2755a
...
...
@@ -64,6 +64,7 @@ using KernelTypes2d = ::testing::Types<std::tuple<float, GNHWC, GKYXC, GNHWK>,
std
::
tuple
<
int8_t
,
NHWGC
,
GKYXC
,
NHWGK
>
,
std
::
tuple
<
float
,
NGCHW
,
GKYXC
,
NGKHW
>
,
std
::
tuple
<
ck
::
half_t
,
NGCHW
,
GKYXC
,
NGKHW
>
,
std
::
tuple
<
ck
::
bhalf_t
,
NGCHW
,
GKYXC
,
NGKHW
>
,
std
::
tuple
<
int8_t
,
NGCHW
,
GKYXC
,
NGKHW
>>
;
using
KernelTypes3d
=
::
testing
::
Types
<
std
::
tuple
<
float
,
GNDHWC
,
GKZYXC
,
GNDHWK
>
,
...
...
Prev
1
…
5
6
7
8
9
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment