gaoqiong / composable_kernel_ROCM · Commits

Commit c32d3448
Authored Jan 18, 2024 by Jun Liu

    Merge branch 'develop' into amd-develop

Parents: 71d6ede7, 402a930a

Showing 19 changed files with 188 additions and 60 deletions (+188 -60):
  .github/CODEOWNERS                                                          +1   -0
  Dockerfile                                                                  +6   -1
  Jenkinsfile                                                                +33   -1
  LICENSE                                                                     +1   -1
  docs/index.rst                                                              +2   -2
  docs/sphinx/requirements.in                                                 +2   -2
  docs/sphinx/requirements.txt                                                +2   -2
  include/ck/ck.hpp                                                           +1   -1
  include/ck/host_utility/hip_check_error.hpp                                +15  -13
  include/ck/host_utility/kernel_launch.hpp                                   +8   -5
  include/ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp   +6   -4
  profiler/include/profiler/profile_gemm_impl.hpp                             +5   -3
  profiler/include/profiler/profile_gemm_splitk_impl.hpp                      +7   -4
  profiler/include/profiler/profile_grouped_gemm_impl.hpp                     +7   -4
  profiler/src/profile_gemm.cpp                                              +14   -2
  profiler/src/profile_gemm_splitk.cpp                                       +15   -2
  profiler/src/profile_grouped_gemm.cpp                                      +31   -7
  test/gemm_split_k/test_gemm_splitk_util.hpp                                +16   -3
  test/grouped_gemm/test_grouped_gemm_util.hpp                               +16   -3
.github/CODEOWNERS

* @zjing14 @asroy @junliume @illsilin @carlushuang

# Documentation files
docs/* @saadrahim @LisaDelaney
*.md @saadrahim @LisaDelaney
Dockerfile

@@ -74,7 +74,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
     apt-get clean && \
     rm -rf /var/lib/apt/lists/*
-#Install latest version of cmake
+#Install ninja build tracing tools
 RUN wget -qO /usr/local/bin/ninja.gz https://github.com/ninja-build/ninja/releases/latest/download/ninja-linux.zip
 RUN gunzip /usr/local/bin/ninja.gz
 RUN chmod a+x /usr/local/bin/ninja

@@ -82,6 +82,11 @@ RUN git clone https://github.com/nico/ninjatracing.git
 # Update the cmake to the latest version
 RUN pip install --upgrade cmake==3.27.5
+#Install latest cppcheck
+RUN git clone https://github.com/danmar/cppcheck.git && \
+    cd cppcheck && mkdir build && cd build && \
+    cmake .. && \
+    cmake --build .
 WORKDIR /
 # Setup ubsan environment to printstacktrace
 RUN ln -s /usr/bin/llvm-symbolizer-3.8 /usr/local/bin/llvm-symbolizer
 ENV UBSAN_OPTIONS=print_stacktrace=1
Jenkinsfile

@@ -304,7 +304,7 @@ def buildHipClangJob(Map conf=[:]){
         gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') {
             withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') {
-                timeout(time: 5, unit: 'HOURS')
+                timeout(time: 20, unit: 'HOURS')
                 {
                     cmake_build(conf)
                 }

@@ -709,6 +709,10 @@ pipeline {
             name: "USE_SCCACHE",
             defaultValue: true,
             description: "Use the sccache for building CK (default: ON)")
+        booleanParam(
+            name: "RUN_CPPCHECK",
+            defaultValue: false,
+            description: "Run the cppcheck static analysis (default: OFF)")
     }
     environment{
         dbuser = "${dbuser}"

@@ -735,7 +739,35 @@ pipeline {
         }
         stage("Static checks") {
             parallel{
+                stage('Clang Format and Cppcheck') {
+                    when {
+                        beforeAgent true
+                        expression { params.RUN_CPPCHECK.toBoolean() }
+                    }
+                    agent{ label rocmnode("nogpu") }
+                    environment{
+                        execute_cmd = "find .. -not -path \'*.git*\' -iname \'*.h\' \
+                                        -o -not -path \'*.git*\' -iname \'*.hpp\' \
+                                        -o -not -path \'*.git*\' -iname \'*.cpp\' \
+                                        -o -iname \'*.h.in\' \
+                                        -o -iname \'*.hpp.in\' \
+                                        -o -iname \'*.cpp.in\' \
+                                        -o -iname \'*.cl\' \
+                                        | grep -v 'build/' \
+                                        | xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-12 -style=file {} | diff - {}\' && \
+                                        /cppcheck/build/bin/cppcheck ../* -v -j \$(nproc) -I ../include -I ../profiler/include -I ../library/include --file-filter=*.cpp --enable=all --output-file=ck_cppcheck.log"
+                    }
+                    steps{
+                        buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", execute_cmd: execute_cmd, no_reboot:true)
+                        archiveArtifacts "build/ck_cppcheck.log"
+                        cleanWs()
+                    }
+                }
                 stage('Clang Format') {
+                    when {
+                        beforeAgent true
+                        expression { !params.RUN_CPPCHECK.toBoolean() }
+                    }
                     agent{ label rocmnode("nogpu") }
                     environment{
                         execute_cmd = "find .. -not -path \'*.git*\' -iname \'*.h\' \
...
LICENSE

@@ -7,7 +7,7 @@ Copyright (c) 2020 , Advanced Micro Devices, Inc. (Xiaoyan Zhou)
 Copyright (c) 2021-2022, Advanced Micro Devices, Inc. (Jianfeng Yan)

 SPDX-License-Identifier: MIT
-Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 Permission is hereby granted, free of charge, to any person obtaining a copy
 of this software and associated documentation files (the "Software"), to deal
docs/index.rst

@@ -34,6 +34,6 @@ The CK documentation is structured as follows:
 * :ref:`contributing-to`

-To contribute to the documentation refer to `Contributing to ROCm <https://rocm.docs.amd.com/en/latest/contribute/index.md>`_.
+To contribute to the documentation refer to `Contributing to ROCm <https://rocm.docs.amd.com/en/latest/contribute/index.html>`_.

-You can find licensing information at the `Licensing <https://rocm.docs.amd.com/en/latest/about/license.md>`_ page.
+You can find licensing information on the `Licensing <https://rocm.docs.amd.com/en/latest/about/license.html>`_ page.
docs/sphinx/requirements.in

-rocm-docs-core==0.30.3
-sphinxcontrib-bibtex==2.6.1
+rocm-docs-core==0.31.0
+sphinxcontrib-bibtex==2.6.2
docs/sphinx/requirements.txt

@@ -113,7 +113,7 @@ requests==2.31.0
     # via
     #   pygithub
     #   sphinx
-rocm-docs-core==0.30.3
+rocm-docs-core==0.31.0
     # via -r requirements.in
 six==1.16.0
     # via

@@ -149,7 +149,7 @@ sphinx-notfound-page==0.8.3
     # via rocm-docs-core
 sphinxcontrib-applehelp==1.0.4
     # via sphinx
-sphinxcontrib-bibtex==2.6.1
+sphinxcontrib-bibtex==2.6.2
     # via -r requirements.in
 sphinxcontrib-devhelp==1.0.2
     # via sphinx
include/ck/ck.hpp

@@ -218,7 +218,7 @@
 // denorm test fix, required to work around dissue
 #ifndef CK_WORKAROUND_DENORM_FIX
 #define CK_WORKAROUND_DENORM_FIX 0
-#elif
+#else
 // enable only on MI200
 #define CK_WORKAROUND_DENORM_FIX = CK_WORKAROUND_DENORM_FIX && defined(__gfx90a__)
 #endif // CK_WORKAROUND_DENORM_FIX
include/ck/host_utility/hip_check_error.hpp

@@ -12,21 +12,23 @@ inline void hip_check_error(hipError_t x)
     if(x != hipSuccess)
     {
         std::ostringstream ss;
-        ss << "HIP runtime error: " << hipGetErrorString(x) << ". " << __FILE__ << ": "
-           << __LINE__ << "in function: " << __func__;
+        ss << "HIP runtime error: " << hipGetErrorString(x) << ". "
+           << "hip_check_error.hpp"
+           << ": " << __LINE__ << "in function: " << __func__;
         throw std::runtime_error(ss.str());
     }
 }

-#define HIP_CHECK_ERROR(retval_or_funcall)                                      \
-    do                                                                          \
-    {                                                                           \
-        hipError_t _tmpVal = retval_or_funcall;                                 \
-        if(_tmpVal != hipSuccess)                                               \
-        {                                                                       \
-            std::ostringstream ostr;                                            \
-            ostr << "HIP Function Failed (" << __FILE__ << "," << __LINE__      \
-                 << ") " << hipGetErrorString(_tmpVal);                         \
-            throw std::runtime_error(ostr.str());                               \
-        }                                                                       \
+#define HIP_CHECK_ERROR(retval_or_funcall)                                      \
+    do                                                                          \
+    {                                                                           \
+        hipError_t _tmpVal = retval_or_funcall;                                 \
+        if(_tmpVal != hipSuccess)                                               \
+        {                                                                       \
+            std::ostringstream ostr;                                            \
+            ostr << "HIP Function Failed ("                                     \
+                 << "hip_check_error.hpp"                                       \
+                 << "," << __LINE__ << ") " << hipGetErrorString(_tmpVal);      \
+            throw std::runtime_error(ostr.str());                               \
+        }                                                                       \
     } while(0)
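Both helpers wrap HIP calls that return hipError_t and throw std::runtime_error on failure. A minimal call-site sketch, for illustration only (copy_weights_to_device is a hypothetical helper, not CK code):

    #include <cstddef>
    #include <hip/hip_runtime.h>
    #include "ck/host_utility/hip_check_error.hpp"

    void copy_weights_to_device(void* dst, const void* src, std::size_t bytes)
    {
        // Function form: forward the hipError_t return value to the checker.
        hip_check_error(hipMemcpy(dst, src, bytes, hipMemcpyHostToDevice));

        // Macro form: the do/while(0) wrapper lets it be used as a statement;
        // it throws with file/line context if the call fails.
        HIP_CHECK_ERROR(hipDeviceSynchronize());
    }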
include/ck/host_utility/kernel_launch.hpp

@@ -30,7 +30,7 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
                block_dim.y,
                block_dim.z);

-        printf("Warm up 1 time\n");
+        printf("Warm up %d times\n", stream_config.cold_niters_);
 #endif
         // warm up
         for(int i = 0; i < stream_config.cold_niters_; ++i)

@@ -103,14 +103,17 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
                block_dim.y,
                block_dim.z);

-        printf("Warm up 1 time\n");
+        printf("Warm up %d times\n", stream_config.cold_niters_);
 #endif
         // warm up
         preprocess();
-        kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
-        hip_check_error(hipGetLastError());
+        for(int i = 0; i < stream_config.cold_niters_; ++i)
+        {
+            kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
+            hip_check_error(hipGetLastError());
+        }

-        const int nrepeat = 10;
+        const int nrepeat = stream_config.nrepeat_;
 #if DEBUG_LOG
         printf("Start running %d times...\n", nrepeat);
 #endif
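These hunks parameterize the standard benchmarking loop: cold_niters_ untimed warm-up launches, then nrepeat_ timed launches whose average is reported. A minimal host-side sketch of that pattern, assuming wall-clock timing rather than the event-based timing CK actually uses:

    #include <chrono>

    // Generic sketch: `launch` stands in for the kernel<<<...>>> call;
    // cold_niters/nrepeat mirror stream_config.cold_niters_ and
    // stream_config.nrepeat_ from the hunks above.
    template <typename F>
    float average_ms(F&& launch, int cold_niters, int nrepeat)
    {
        for(int i = 0; i < cold_niters; ++i)
            launch(); // warm up: not measured (code caches, clock ramp-up)

        const auto start = std::chrono::steady_clock::now();
        for(int i = 0; i < nrepeat; ++i)
            launch(); // measured phase
        const auto stop = std::chrono::steady_clock::now();

        return std::chrono::duration<float, std::milli>(stop - start).count() / nrepeat;
    }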
include/ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp

@@ -35,15 +35,17 @@ auto CalculateMaxRead(const std::vector<index_t>& lengths, const std::vector<ind
     if(lengths.size() != NumDim1 + NumDim2)
     {
         std::ostringstream err;
-        err << "Incorrect number of lengths in " << __FILE__ << ":" << __LINE__
-            << ", in function: " << __func__;
+        err << "Incorrect number of lengths in "
+            << "device_contraction_utils.hpp"
+            << ":" << __LINE__ << ", in function: " << __func__;
         throw std::runtime_error(err.str());
     }

     if(strides.size() != NumDim1 + NumDim2)
     {
         std::ostringstream err;
-        err << "Incorrect number of strides in " << __FILE__ << ":" << __LINE__
-            << ", in function: " << __func__;
+        err << "Incorrect number of strides in "
+            << "device_contraction_utils.hpp"
+            << ":" << __LINE__ << ", in function: " << __func__;
         throw std::runtime_error(err.str());
     }
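Both this hunk and the hip_check_error.hpp one replace __FILE__ with a hard-coded file name in error messages. The commit does not state its motivation; one plausible reason, illustrated below, is that __FILE__ expands to whatever path the compiler was invoked with, so out-of-tree builds can embed long absolute paths into error strings and binaries:

    #include <iostream>

    int main()
    {
        // __FILE__ is the path as the compiler saw it; in out-of-tree builds
        // this is often a long absolute path such as
        // "/home/user/build/../include/ck/host_utility/hip_check_error.hpp".
        // A string literal keeps the message short and build-independent.
        std::cout << "compiled from: " << __FILE__ << '\n';
        return 0;
    }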
profiler/include/profiler/profile_gemm_impl.hpp

@@ -42,7 +42,9 @@ int profile_gemm_impl(int do_verification,
                       int K,
                       int StrideA,
                       int StrideB,
-                      int StrideC)
+                      int StrideC,
+                      int n_warmup,
+                      int n_iter)
 {
     bool pass = true;

@@ -165,8 +167,8 @@ int profile_gemm_impl(int do_verification,
         std::string op_name = op_ptr->GetTypeString();

-        float avg_time = invoker_ptr->Run(argument_ptr.get(),
-                                          StreamConfig{nullptr, time_kernel, 0, 10, 50});
+        float avg_time = invoker_ptr->Run(argument_ptr.get(),
+                                          StreamConfig{nullptr, time_kernel, 0, n_warmup, n_iter});

         std::size_t flop = std::size_t(2) * M * N * K;
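The five-field initializer StreamConfig{nullptr, time_kernel, 0, n_warmup, n_iter} implies a layout along these lines. Only stream_id_, cold_niters_ and nrepeat_ are confirmed by the kernel_launch.hpp hunks above; the remaining field names and defaults in this sketch are assumptions:

    #include <hip/hip_runtime.h>

    // Assumed shape of StreamConfig, inferred from the aggregate initializers
    // in this diff; log_level_ is a guessed name for the third field.
    struct StreamConfig
    {
        hipStream_t stream_id_ = nullptr; // stream to launch on
        bool time_kernel_      = false;   // whether to time the kernel at all
        int log_level_         = 0;       // assumption: third positional field
        int cold_niters_       = 1;       // warm-up launches (now n_warmup)
        int nrepeat_           = 10;      // timed launches (now n_iter)
    };

Note that the old hard-coded call used StreamConfig{nullptr, time_kernel, 0, 10, 50}, i.e. 10 warm-up and 50 timed iterations; the change merely makes those two counts caller-controlled.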
profiler/include/profiler/profile_gemm_splitk_impl.hpp

@@ -42,7 +42,9 @@ bool profile_gemm_splitk_impl(int do_verification,
                               int StrideA,
                               int StrideB,
                               int StrideC,
-                              int KBatch)
+                              int KBatch,
+                              int n_warmup,
+                              int n_iter)
 {
     bool pass = true;

@@ -177,7 +179,8 @@ bool profile_gemm_splitk_impl(int do_verification,
         // re-init C to zero before profiling next kernel
         c_device_buf.SetZero();

-        invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
+        invoker_ptr->Run(argument_ptr.get(),
+                         StreamConfig{nullptr, false, 0, n_warmup, n_iter});

         if(do_verification)
         {

@@ -200,8 +203,8 @@ bool profile_gemm_splitk_impl(int do_verification,
         std::string op_name = op_ptr->GetTypeString();

-        float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
+        float ave_time = invoker_ptr->Run(argument_ptr.get(),
+                                          StreamConfig{nullptr, time_kernel, 0, n_warmup, n_iter});

         std::size_t flop = std::size_t(2) * M * N * K;
profiler/include/profiler/profile_grouped_gemm_impl.hpp

@@ -42,7 +42,9 @@ bool profile_grouped_gemm_impl(int do_verification,
                                const std::vector<int>& StrideAs,
                                const std::vector<int>& StrideBs,
                                const std::vector<int>& StrideCs,
-                               int kbatch = 1)
+                               int kbatch   = 1,
+                               int n_warmup = 1,
+                               int n_iter   = 10)
 {
     bool pass = true;

@@ -261,7 +263,8 @@ bool profile_grouped_gemm_impl(int do_verification,
         for(std::size_t i = 0; i < gemm_descs.size(); i++)
             c_device_buf[i]->SetZero();

-        invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
+        invoker_ptr->Run(argument_ptr.get(),
+                         StreamConfig{nullptr, false, 0, n_warmup, n_iter});

         if(do_verification)
         {

@@ -307,8 +310,8 @@ bool profile_grouped_gemm_impl(int do_verification,
             pass = pass && instance_pass;
         }

-        float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
+        float ave_time = invoker_ptr->Run(argument_ptr.get(),
+                                          StreamConfig{nullptr, time_kernel, 0, n_warmup, n_iter});

         if(time_kernel)
         {
...
profiler/src/profile_gemm.cpp

@@ -42,12 +42,15 @@ static void print_helper_msg()
               << "arg6: print tensor value (0: no; 1: yes)\n"
               << "arg7: time kernel (0: no, 1: yes)\n"
               << "arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n"
+              << "optional:\n"
+              << "arg14: number of warm-up cycles (default 1)\n"
+              << "arg15: number of iterations (default 10)\n"
               << std::endl;
 }

 int profile_gemm(int argc, char* argv[])
 {
-    if(argc != 14)
+    if(argc != 14 && argc != 16)
     {
         print_helper_msg();
         exit(1);

@@ -68,6 +71,13 @@ int profile_gemm(int argc, char* argv[])
     const int StrideB = std::stoi(argv[12]);
     const int StrideC = std::stoi(argv[13]);

+    int n_warmup = 1;
+    int n_iter   = 10;
+    if(argc == 16)
+    {
+        n_warmup = std::stoi(argv[14]);
+        n_iter   = std::stoi(argv[15]);
+    }

     using F32 = float;
     using F16 = ck::half_t;
 #ifdef CK_ENABLE_BF16

@@ -120,7 +130,9 @@ int profile_gemm(int argc, char* argv[])
             K,
             (StrideA < 0) ? DefaultStrideA : StrideA,
             (StrideB < 0) ? DefaultStrideB : StrideB,
-            (StrideC < 0) ? DefaultStrideC : StrideC);
+            (StrideC < 0) ? DefaultStrideC : StrideC,
+            n_warmup,
+            n_iter);

         return pass ? 0 : 1;
     };
profiler/src/profile_gemm_splitk.cpp

@@ -33,7 +33,7 @@ enum struct GemmDataType
 int profile_gemm_splitk(int argc, char* argv[])
 {
-    if(argc != 15)
+    if(argc != 15 && argc != 17)
     {
         printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n");
         printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8; 6: f16, "
...

@@ -48,6 +48,9 @@ int profile_gemm_splitk(int argc, char* argv[])
         printf("arg7: time kernel (0=no, 1=yes)\n");
         printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n");
         printf("arg14: split k into mulitiple batch\n");
+        printf("optional:\n");
+        printf("arg15: number of warm-up cycles (default 1)\n");
+        printf("arg16: number of iterations (default 10)\n");
         exit(1);
     }

@@ -67,6 +70,14 @@ int profile_gemm_splitk(int argc, char* argv[])
     const int StrideC = std::stoi(argv[13]);
     const int KBatch  = std::stoi(argv[14]);

+    int n_warmup = 1;
+    int n_iter   = 10;
+    if(argc == 17)
+    {
+        n_warmup = std::stoi(argv[15]);
+        n_iter   = std::stoi(argv[16]);
+    }

     using F32 = float;
     using F16 = ck::half_t;
 #if defined CK_ENABLE_FP8

@@ -117,7 +128,9 @@ int profile_gemm_splitk(int argc, char* argv[])
             (StrideA < 0) ? DefaultStrideA : StrideA,
             (StrideB < 0) ? DefaultStrideB : StrideB,
             (StrideC < 0) ? DefaultStrideC : StrideC,
-            KBatch);
+            KBatch,
+            n_warmup,
+            n_iter);

         return pass ? 0 : 1;
     };
profiler/src/profile_grouped_gemm.cpp

@@ -69,7 +69,10 @@ int profile_grouped_gemm(int argc, char* argv[])
               << "arg7: time kernel (0=n0, 1=yes)\n"
               << "arg8 to 13: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 "
                  "64,64 64,64 128,128)\n"
-              << "arg15: kbatch value (default 4)\n"
+              << "arg15: kbatch value (default 1)\n"
+              << "optional:\n"
+              << "arg16: number of warm-up cycles (default 1)\n"
+              << "arg17: number of iterations (default 10)\n"
               << std::endl;

         exit(1);

@@ -90,6 +93,15 @@ int profile_grouped_gemm(int argc, char* argv[])
     const auto StrideBs = argToIntArray(argv[12]);
     const auto StrideCs = argToIntArray(argv[13]);
     const int kbatch    = argc == 15 ? std::stoi(argv[14]) : 1;

+    int n_warmup = 1;
+    int n_iter   = 10;
+    if(argc == 17)
+    {
+        n_warmup = std::stoi(argv[16]);
+        n_iter   = std::stoi(argv[17]);
+    }

 #ifdef CK_ENABLE_FP16
     if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
     {

The remaining six hunks (@@ -109,7 +121,9 @@, @@ -129,7 +143,9 @@, @@ -149,7 +165,9 @@, @@ -169,7 +187,9 @@, @@ -189,7 +209,9 @@ and @@ -209,7 +231,9 @@) apply the same change to each profile_grouped_gemm_impl call site, one per data-type/layout branch (F16_F16_F16 with MK_KN_MN, MK_NK_MN, KM_KN_MN and KM_NK_MN; F8_F16_F16 and F16_F8_F16 with MK_KN_MN):

             StrideAs,
             StrideBs,
             StrideCs,
-            kbatch);
+            kbatch,
+            n_warmup,
+            n_iter);
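The grouped profiler reads Ms/Ns/Ks and the stride lists as comma-separated strings via argToIntArray, whose definition is outside this diff. A plausible reimplementation, for illustration only (the real helper may differ):

    #include <sstream>
    #include <string>
    #include <vector>

    // Hypothetical stand-in for CK's argToIntArray: parse "256,256,128"
    // into a vector<int> by splitting on commas.
    std::vector<int> argToIntArray(const char* arg)
    {
        std::vector<int> values;
        std::istringstream stream(arg);
        for(std::string token; std::getline(stream, token, ',');)
            values.push_back(std::stoi(token));
        return values;
    }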
test/gemm_split_k/test_gemm_splitk_util.hpp

@@ -60,7 +60,9 @@ class TestGemmSplitK : public testing::Test
                  const int StrideA,
                  const int StrideB,
                  const int StrideC,
-                 int kbatch = 1)
+                 int kbatch   = 1,
+                 int n_warmup = 1,
+                 int n_iter   = 10)
     {
         bool pass = ck::profiler::profile_gemm_splitk_impl<ADataType,
                                                            BDataType,

@@ -68,8 +70,19 @@ class TestGemmSplitK : public testing::Test
                                                            CDataType,
                                                            ALayout,
                                                            BLayout,
-                                                           CLayout>(
-            verify_, init_method_, log_, bench_, M, N, K, StrideA, StrideB, StrideC, kbatch);
+                                                           CLayout>(verify_,
+                                                                    init_method_,
+                                                                    log_,
+                                                                    bench_,
+                                                                    M,
+                                                                    N,
+                                                                    K,
+                                                                    StrideA,
+                                                                    StrideB,
+                                                                    StrideC,
+                                                                    kbatch,
+                                                                    n_warmup,
+                                                                    n_iter);
         EXPECT_TRUE(pass);
     }
 };
test/grouped_gemm/test_grouped_gemm_util.hpp

@@ -63,7 +63,9 @@ class TestGroupedGemm : public testing::TestWithParam<int>
              const std::vector<int>& StrideAs,
              const std::vector<int>& StrideBs,
              const std::vector<int>& StrideCs,
-             int kbatch = 1)
+             int kbatch   = 1,
+             int n_warmup = 1,
+             int n_iter   = 10)
     {
         bool pass = ck::profiler::profile_grouped_gemm_impl<ADataType,
                                                             BDataType,

@@ -71,8 +73,19 @@ class TestGroupedGemm : public testing::TestWithParam<int>
                                                             float,
                                                             ALayout,
                                                             BLayout,
-                                                            ELayout>(
-            verify_, init_method_, log_, bench_, Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, kbatch);
+                                                            ELayout>(verify_,
+                                                                     init_method_,
+                                                                     log_,
+                                                                     bench_,
+                                                                     Ms,
+                                                                     Ns,
+                                                                     Ks,
+                                                                     StrideAs,
+                                                                     StrideBs,
+                                                                     StrideCs,
+                                                                     kbatch,
+                                                                     n_warmup,
+                                                                     n_iter);
         EXPECT_TRUE(pass);
     }
 };
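Throughout the diff, n_warmup and n_iter are appended as trailing parameters with defaults (1 and 10), which is why the existing test and profiler call sites compile unchanged. A self-contained illustration of that compatibility property (the function below is illustrative, not CK's):

    #include <iostream>

    // Illustrative signature mirroring the extension pattern used above.
    bool run_profile(int kbatch = 1, int n_warmup = 1, int n_iter = 10)
    {
        std::cout << "kbatch=" << kbatch << " n_warmup=" << n_warmup
                  << " n_iter=" << n_iter << '\n';
        return true;
    }

    int main()
    {
        run_profile();          // pre-existing call sites compile unchanged
        run_profile(4);         // kbatch-only callers are also unchanged
        run_profile(4, 5, 100); // new callers can tune the benchmark loop
        return 0;
    }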