Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
d83e2d25
"docs/source/en/api/pipelines/overview.md" did not exist on "8aac1f99d7af5873db7d23c07fba370d0f5061a6"
Unverified
Commit
d83e2d25
authored
Dec 14, 2024
by
feli
Committed by
GitHub
Dec 14, 2024
Browse files
Merge branch 'develop' into lynm_fwd_bias
parents
006df96c
d68974a5
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
205 additions
and
147 deletions
+205
-147
Dockerfile
Dockerfile
+2
-1
example/24_batched_gemm/batched_gemm_xdl_bf16_v3.cpp
example/24_batched_gemm/batched_gemm_xdl_bf16_v3.cpp
+2
-2
example/ck_tile/12_smoothquant/example_smoothquant.cpp
example/ck_tile/12_smoothquant/example_smoothquant.cpp
+26
-18
example/ck_tile/12_smoothquant/smoothquant.cpp
example/ck_tile/12_smoothquant/smoothquant.cpp
+26
-18
include/ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp
...nsor_operation/gpu/device/device_batched_gemm_multi_d.hpp
+2
-1
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp
...e/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp
+29
-16
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp
...ration/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp
+8
-8
include/ck_tile/ops/smoothquant/kernel/smoothquant_kernel.hpp
...ude/ck_tile/ops/smoothquant/kernel/smoothquant_kernel.hpp
+14
-6
library/src/tensor_operation_instance/gpu/gemm_universal_batched/device_batched_gemm_xdl_universal_bf16_bf16_bf16/device_batched_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn.hpp
...ce_batched_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn.hpp
+3
-0
library/src/tensor_operation_instance/gpu/gemm_universal_batched/device_batched_gemm_xdl_universal_f8_f8_bf16/device_batched_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn.hpp
...device_batched_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn.hpp
+2
-0
profiler/include/profiler/profile_gemm_universal_batched_impl.hpp
.../include/profiler/profile_gemm_universal_batched_impl.hpp
+80
-68
profiler/src/profile_gemm_universal_batched.cpp
profiler/src/profile_gemm_universal_batched.cpp
+11
-9
No files found.
Dockerfile
View file @
d83e2d25
...
...
@@ -64,6 +64,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
nano
\
zlib1g-dev
\
zip
\
libzstd-dev
\
openssh-server
\
clang-format-12
\
kmod
&&
\
...
...
@@ -93,7 +94,7 @@ RUN pip install --upgrade cmake==3.27.5 && \
dpkg -i dumb-init_*.deb && rm dumb-init_*.deb && \
# Install packages for processing the performance results
pip3 install --upgrade pip && \
pip3 install sqlalchemy==1.4.46 pymysql pandas==2.
0
.3 setuptools-rust sshtunnel==0.4.0 && \
pip3 install sqlalchemy==1.4.46 pymysql pandas==2.
2
.3 setuptools-rust sshtunnel==0.4.0 && \
# Add render group
groupadd -f render && \
# Install the new rocm-cmake version
...
...
example/24_batched_gemm/batched_gemm_xdl_bf16_v3.cpp
View file @
d83e2d25
...
...
@@ -78,14 +78,14 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD
2
,
// ABlockTransferSrcVectorDim
8
,
// ABlockTransferSrcScalarPerVector
8
,
// ABlockTransferDstScalarPerVector_AK1
1
,
// ABlockLdsExtraM
0
,
// ABlockLdsExtraM
S
<
4
,
64
,
1
>
,
// BBlockTransferThreadClusterLengths_BK0_N_BK1
S
<
1
,
0
,
2
>
,
// BBlockTransferThreadClusterArrangeOrder
S
<
1
,
0
,
2
>
,
// BBlockTransferSrcAccessOrder
2
,
// BBlockTransferSrcVectorDim
8
,
// BBlockTransferSrcScalarPerVector
8
,
// BBlockTransferDstScalarPerVector_BK1
1
,
// BBlockLdsExtraN
0
,
// BBlockLdsExtraN
1
,
// CShuffleMXdlPerWavePerShuffle
1
,
// CShuffleNXdlPerWavePerShuffle
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
...
...
example/ck_tile/12_smoothquant/example_smoothquant.cpp
View file @
d83e2d25
...
...
@@ -35,7 +35,8 @@ auto create_args(int argc, char* argv[])
ck_tile
::
ArgParser
arg_parser
;
arg_parser
.
insert
(
"m"
,
"3328"
,
"m dimension"
)
.
insert
(
"n"
,
"4096"
,
"n dimension"
)
.
insert
(
"stride"
,
"-1"
,
"stride per row, if -1 then equal to n"
)
.
insert
(
"x_stride"
,
"-1"
,
"input stride per row, if -1 then equal to n"
)
.
insert
(
"y_stride"
,
"-1"
,
"output stride per row, if -1 then equal to n"
)
.
insert
(
"e"
,
"1e-5"
,
"epsilon"
)
.
insert
(
"v"
,
"1"
,
"cpu validation or not"
)
.
insert
(
"prec"
,
"fp16"
,
"precision"
)
...
...
@@ -49,11 +50,14 @@ auto create_args(int argc, char* argv[])
template
<
typename
DataType
>
bool
run
(
const
ck_tile
::
ArgParser
&
arg_parser
)
{
ck_tile
::
index_t
m
=
arg_parser
.
get_int
(
"m"
);
ck_tile
::
index_t
n
=
arg_parser
.
get_int
(
"n"
);
ck_tile
::
index_t
stride
=
arg_parser
.
get_int
(
"stride"
);
if
(
stride
<
0
)
stride
=
n
;
ck_tile
::
index_t
m
=
arg_parser
.
get_int
(
"m"
);
ck_tile
::
index_t
n
=
arg_parser
.
get_int
(
"n"
);
ck_tile
::
index_t
x_stride
=
arg_parser
.
get_int
(
"x_stride"
);
if
(
x_stride
<
0
)
x_stride
=
n
;
ck_tile
::
index_t
y_stride
=
arg_parser
.
get_int
(
"y_stride"
);
if
(
y_stride
<
0
)
y_stride
=
n
;
std
::
string
data_type
=
arg_parser
.
get_str
(
"prec"
);
int
do_validation
=
arg_parser
.
get_int
(
"v"
);
int
warmup
=
arg_parser
.
get_int
(
"warmup"
);
...
...
@@ -68,14 +72,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
using
ComputeDataType
=
float
;
// host verify
ck_tile
::
HostTensor
<
XDataType
>
x_host
({
m
,
n
},
{
stride
,
1
});
ck_tile
::
HostTensor
<
XDataType
>
x_host
({
m
,
n
},
{
x_
stride
,
1
});
ck_tile
::
HostTensor
<
XScaleDataType
>
xscale_host
({
n
});
ck_tile
::
HostTensor
<
YScaleDataType
>
yscale_host_ref
({
m
},
{
1
});
ck_tile
::
HostTensor
<
YScaleDataType
>
yscale_host_dev
({
m
},
{
1
});
ck_tile
::
HostTensor
<
QYDataType
>
qy_host_ref
({
m
,
n
},
{
stride
,
1
});
ck_tile
::
HostTensor
<
QYDataType
>
qy_host_dev
({
m
,
n
},
{
stride
,
1
});
ck_tile
::
HostTensor
<
QYDataType
>
qy_host_ref
({
m
,
n
},
{
y_
stride
,
1
});
ck_tile
::
HostTensor
<
QYDataType
>
qy_host_dev
({
m
,
n
},
{
y_
stride
,
1
});
ck_tile
::
FillUniformDistribution
<
XDataType
>
{
-
.5
f
,
.5
f
}(
x_host
);
ck_tile
::
FillUniformDistribution
<
XScaleDataType
>
{
1e-3
,
.5
f
}(
xscale_host
);
...
...
@@ -116,7 +120,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
qy_buf
.
GetDeviceBuffer
(),
m
,
n
,
stride
};
x_stride
,
y_stride
};
auto
kargs
=
Kernel
::
MakeKargs
(
args
);
...
...
@@ -133,7 +138,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
if
(
do_validation
)
{
using
YDataType
=
ComputeDataType
;
ck_tile
::
HostTensor
<
ComputeDataType
>
y_host
({
m
,
n
},
{
stride
,
1
});
ck_tile
::
HostTensor
<
ComputeDataType
>
y_host
({
m
,
n
},
{
y_
stride
,
1
});
// smooth outlier
{
auto
f
=
[
&
](
auto
n_
)
{
...
...
@@ -183,7 +188,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
qy_buf
.
FromDevice
(
qy_host_dev
.
data
());
auto
[
rtol
,
atol
]
=
get_elimit
<
QYDataType
>
();
if
(
stride
==
n
)
if
(
y_
stride
==
n
)
{
pass
=
ck_tile
::
check_err
(
qy_host_dev
,
qy_host_ref
,
...
...
@@ -195,10 +200,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
{
for
(
int
i_r
=
0
;
i_r
<
m
;
i_r
++
)
{
std
::
vector
<
QYDataType
>
qy_host_dev_row
(
qy_host_dev
.
begin
()
+
i_r
*
stride
,
qy_host_dev
.
begin
()
+
i_r
*
stride
+
n
);
std
::
vector
<
QYDataType
>
qy_host_ref_row
(
qy_host_ref
.
begin
()
+
i_r
*
stride
,
qy_host_ref
.
begin
()
+
i_r
*
stride
+
n
);
std
::
vector
<
QYDataType
>
qy_host_dev_row
(
qy_host_dev
.
begin
()
+
i_r
*
y_stride
,
qy_host_dev
.
begin
()
+
i_r
*
y_stride
+
n
);
std
::
vector
<
QYDataType
>
qy_host_ref_row
(
qy_host_ref
.
begin
()
+
i_r
*
y_stride
,
qy_host_ref
.
begin
()
+
i_r
*
y_stride
+
n
);
pass
&=
ck_tile
::
check_err
(
qy_host_dev_row
,
qy_host_ref_row
,
std
::
string
(
"qy["
)
+
std
::
to_string
(
i_r
)
+
...
...
@@ -210,8 +217,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
std
::
cout
<<
"["
<<
data_type
<<
"]"
<<
" m:"
<<
m
<<
", n:"
<<
n
<<
", stride:"
<<
stride
<<
", valid:"
<<
(
pass
?
"y"
:
"n"
)
<<
std
::
flush
<<
std
::
endl
;
<<
" m:"
<<
m
<<
", n:"
<<
n
<<
", x_stride:"
<<
x_stride
<<
", y_stride:"
<<
y_stride
<<
", valid:"
<<
(
pass
?
"y"
:
"n"
)
<<
std
::
flush
<<
std
::
endl
;
}
return
pass
;
...
...
example/ck_tile/12_smoothquant/smoothquant.cpp
View file @
d83e2d25
...
...
@@ -33,7 +33,8 @@ auto create_args(int argc, char* argv[])
ck_tile
::
ArgParser
arg_parser
;
arg_parser
.
insert
(
"m"
,
"3328"
,
"m dimension"
)
.
insert
(
"n"
,
"4096"
,
"n dimension"
)
.
insert
(
"stride"
,
"-1"
,
"stride per row, if -1 then equal to n"
)
.
insert
(
"x_stride"
,
"-1"
,
"input stride per row, if -1 then equal to n"
)
.
insert
(
"y_stride"
,
"-1"
,
"output stride per row, if -1 then equal to n"
)
.
insert
(
"v"
,
"1"
,
"cpu validation or not"
)
.
insert
(
"kname"
,
"1"
,
"print kernel name or not"
)
.
insert
(
"prec"
,
"fp16"
,
"precision"
)
...
...
@@ -47,18 +48,21 @@ auto create_args(int argc, char* argv[])
template
<
typename
DataType
>
bool
run
(
const
ck_tile
::
ArgParser
&
arg_parser
)
{
ck_tile
::
index_t
m
=
arg_parser
.
get_int
(
"m"
);
ck_tile
::
index_t
n
=
arg_parser
.
get_int
(
"n"
);
ck_tile
::
index_t
stride
=
arg_parser
.
get_int
(
"stride"
);
if
(
stride
<
0
)
stride
=
n
;
ck_tile
::
index_t
m
=
arg_parser
.
get_int
(
"m"
);
ck_tile
::
index_t
n
=
arg_parser
.
get_int
(
"n"
);
ck_tile
::
index_t
x_stride
=
arg_parser
.
get_int
(
"x_stride"
);
if
(
x_stride
<
0
)
x_stride
=
n
;
ck_tile
::
index_t
y_stride
=
arg_parser
.
get_int
(
"y_stride"
);
if
(
y_stride
<
0
)
y_stride
=
n
;
std
::
string
data_type
=
arg_parser
.
get_str
(
"prec"
);
int
kname
=
arg_parser
.
get_int
(
"kname"
);
int
do_validation
=
arg_parser
.
get_int
(
"v"
);
int
warmup
=
arg_parser
.
get_int
(
"warmup"
);
int
repeat
=
arg_parser
.
get_int
(
"repeat"
);
assert
(
stride
>=
n
);
assert
(
x_
stride
>=
n
);
using
TypeConfig
=
SmoothquantTypeConfig
<
DataType
>
;
...
...
@@ -69,14 +73,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
using
ComputeDataType
=
typename
TypeConfig
::
ComputeDataType
;
// host verify
ck_tile
::
HostTensor
<
XDataType
>
x_host
({
m
,
n
},
{
stride
,
1
});
ck_tile
::
HostTensor
<
XDataType
>
x_host
({
m
,
n
},
{
x_
stride
,
1
});
ck_tile
::
HostTensor
<
XScaleDataType
>
xscale_host
({
n
});
ck_tile
::
HostTensor
<
YScaleDataType
>
yscale_host_ref
({
m
},
{
1
});
ck_tile
::
HostTensor
<
YScaleDataType
>
yscale_host_dev
({
m
},
{
1
});
ck_tile
::
HostTensor
<
QYDataType
>
qy_host_ref
({
m
,
n
},
{
stride
,
1
});
ck_tile
::
HostTensor
<
QYDataType
>
qy_host_dev
({
m
,
n
},
{
stride
,
1
});
ck_tile
::
HostTensor
<
QYDataType
>
qy_host_ref
({
m
,
n
},
{
y_
stride
,
1
});
ck_tile
::
HostTensor
<
QYDataType
>
qy_host_dev
({
m
,
n
},
{
y_
stride
,
1
});
ck_tile
::
FillUniformDistribution
<
XDataType
>
{
-
.5
f
,
.5
f
}(
x_host
);
ck_tile
::
FillUniformDistribution
<
XScaleDataType
>
{
1e-3
,
.5
f
}(
xscale_host
);
...
...
@@ -90,7 +94,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
xscale_buf
.
ToDevice
(
xscale_host
.
data
());
std
::
cout
<<
"["
<<
data_type
<<
"]"
<<
" m:"
<<
m
<<
", n:"
<<
n
<<
", stride:"
<<
stride
<<
std
::
flush
;
<<
" m:"
<<
m
<<
", n:"
<<
n
<<
", x_stride:"
<<
x_stride
<<
", y_stride:"
<<
y_stride
<<
std
::
flush
;
smoothquant_traits
traits
{
data_type
};
...
...
@@ -100,7 +105,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
qy_buf
.
GetDeviceBuffer
(),
m
,
n
,
stride
};
x_stride
,
y_stride
};
float
ave_time
=
smoothquant
(
traits
,
args
,
ck_tile
::
stream_config
{
nullptr
,
true
,
kname
?
1
:
0
,
warmup
,
repeat
});
...
...
@@ -116,7 +122,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
if
(
do_validation
)
{
using
YDataType
=
ComputeDataType
;
ck_tile
::
HostTensor
<
ComputeDataType
>
y_host
({
m
,
n
},
{
stride
,
1
});
ck_tile
::
HostTensor
<
ComputeDataType
>
y_host
({
m
,
n
},
{
y_
stride
,
1
});
// smooth outlier
{
auto
f
=
[
&
](
auto
n_
)
{
...
...
@@ -166,7 +172,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
qy_buf
.
FromDevice
(
qy_host_dev
.
data
());
auto
[
rtol
,
atol
]
=
get_elimit
<
QYDataType
>
();
if
(
stride
==
n
)
if
(
y_
stride
==
n
)
{
pass
=
ck_tile
::
check_err
(
qy_host_dev
,
qy_host_ref
,
...
...
@@ -178,10 +184,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
{
for
(
int
i_r
=
0
;
i_r
<
m
;
i_r
++
)
{
std
::
vector
<
QYDataType
>
qy_host_dev_row
(
qy_host_dev
.
begin
()
+
i_r
*
stride
,
qy_host_dev
.
begin
()
+
i_r
*
stride
+
n
);
std
::
vector
<
QYDataType
>
qy_host_ref_row
(
qy_host_ref
.
begin
()
+
i_r
*
stride
,
qy_host_ref
.
begin
()
+
i_r
*
stride
+
n
);
std
::
vector
<
QYDataType
>
qy_host_dev_row
(
qy_host_dev
.
begin
()
+
i_r
*
y_stride
,
qy_host_dev
.
begin
()
+
i_r
*
y_stride
+
n
);
std
::
vector
<
QYDataType
>
qy_host_ref_row
(
qy_host_ref
.
begin
()
+
i_r
*
y_stride
,
qy_host_ref
.
begin
()
+
i_r
*
y_stride
+
n
);
pass
&=
ck_tile
::
check_err
(
qy_host_dev_row
,
qy_host_ref_row
,
std
::
string
(
"qy["
)
+
std
::
to_string
(
i_r
)
+
...
...
include/ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp
View file @
d83e2d25
...
...
@@ -89,7 +89,8 @@ struct DeviceBatchedGemmV2MultiD : public BaseOperator
index_t
BatchStrideE
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
)
=
0
;
CDEElementwiseOperation
cde_element_op
,
index_t
KBatch
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp
View file @
d83e2d25
...
...
@@ -41,12 +41,15 @@ __global__ void
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
const
index_t
g_idx
=
blockIdx
.
z
%
karg
.
Batch
;
const
index_t
k_idx
=
blockIdx
.
z
/
karg
.
Batch
;
const
auto
a_batch_offset
=
karg
.
compute_ptr_offset_of_batch
.
GetAPtrOffset
(
g_idx
);
const
auto
b_batch_offset
=
karg
.
compute_ptr_offset_of_batch
.
GetBPtrOffset
(
g_idx
);
const
auto
ds_batch_offset
=
karg
.
compute_ptr_offset_of_batch
.
GetDsPtrOffset
(
g_idx
);
const
auto
c_batch_offset
=
karg
.
compute_ptr_offset_of_batch
.
GetCPtrOffset
(
g_idx
);
auto
splitk_batch_offset
=
typename
GridwiseGemm
::
SplitKBatchOffset
(
karg
,
k_idx
);
// populate pointer, desc for Ds
static_for
<
0
,
GridwiseGemm
::
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
// D pointer
...
...
@@ -54,8 +57,8 @@ __global__ void
});
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
,
CGlobalMemoryDataOperation
,
TailNum
>(
karg
.
p_a_grid
+
a_batch_offset
,
karg
.
p_b_grid
+
b_batch_offset
,
karg
.
p_a_grid
+
a_batch_offset
+
splitk_batch_offset
.
a_k_split_offset
,
karg
.
p_b_grid
+
b_batch_offset
+
splitk_batch_offset
.
b_k_split_offset
,
karg
.
p_ds_grid
,
karg
.
p_c_grid
+
c_batch_offset
,
p_shared
,
...
...
@@ -87,12 +90,15 @@ __global__ void
__shared__
char
p_shared_1
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
const
index_t
g_idx
=
blockIdx
.
z
%
karg
.
Batch
;
const
index_t
k_idx
=
blockIdx
.
z
/
karg
.
Batch
;
const
auto
a_batch_offset
=
karg
.
compute_ptr_offset_of_batch
.
GetAPtrOffset
(
g_idx
);
const
auto
b_batch_offset
=
karg
.
compute_ptr_offset_of_batch
.
GetBPtrOffset
(
g_idx
);
const
auto
ds_batch_offset
=
karg
.
compute_ptr_offset_of_batch
.
GetDsPtrOffset
(
g_idx
);
const
auto
c_batch_offset
=
karg
.
compute_ptr_offset_of_batch
.
GetCPtrOffset
(
g_idx
);
auto
splitk_batch_offset
=
typename
GridwiseGemm
::
SplitKBatchOffset
(
karg
,
k_idx
);
// populate pointer, desc for Ds
static_for
<
0
,
GridwiseGemm
::
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
// D pointer
...
...
@@ -100,8 +106,8 @@ __global__ void
});
GridwiseGemm
::
template
Run_2Lds
<
HasMainKBlockLoop
,
CGlobalMemoryDataOperation
,
TailNum
>(
karg
.
p_a_grid
+
a_batch_offset
,
karg
.
p_b_grid
+
b_batch_offset
,
karg
.
p_a_grid
+
a_batch_offset
+
splitk_batch_offset
.
a_k_split_offset
,
karg
.
p_b_grid
+
b_batch_offset
+
splitk_batch_offset
.
b_k_split_offset
,
karg
.
p_ds_grid
,
karg
.
p_c_grid
+
c_batch_offset
,
p_shared_0
,
...
...
@@ -303,7 +309,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
index_t
Batch_
,
AElementwiseOperation
a_element_op_
,
BElementwiseOperation
b_element_op_
,
CElementwiseOperation
c_element_op_
)
CElementwiseOperation
c_element_op_
,
index_t
KBatch_
)
:
GridwiseGemm
::
Argument
{
p_a_grid_
,
p_b_grid_
,
p_ds_grid_
,
...
...
@@ -315,7 +322,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
StrideB_
,
StrideDs_
,
StrideE_
,
1
,
KBatch_
,
a_element_op_
,
b_element_op_
,
c_element_op_
},
...
...
@@ -336,13 +343,14 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
arg
.
Print
();
}
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
)
||
arg
.
KBatch
>
1
)
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
))
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemm has invalid setting"
);
}
index_t
gdx
,
gdy
,
gdz
;
std
::
tie
(
gdx
,
gdy
,
gdz
)
=
GridwiseGemm
::
CalculateGridSize
(
arg
.
M
,
arg
.
N
,
arg
.
Batch
);
std
::
tie
(
gdx
,
gdy
,
gdz
)
=
GridwiseGemm
::
CalculateGridSize
(
arg
.
M
,
arg
.
N
,
arg
.
Batch
*
arg
.
KBatch
);
float
ave_time
=
0
;
...
...
@@ -387,10 +395,11 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
rotating_mem
.
Next
();
// clear c mem
if
(
arg_
.
KBatch
>
1
)
hipGetErrorString
(
hipMemsetAsync
(
arg_
.
p_c_grid
,
0
,
arg_
.
M
*
arg_
.
N
*
sizeof
(
CDataType
),
stream_config
.
stream_id_
));
hipGetErrorString
(
hipMemsetAsync
(
arg_
.
p_c_grid
,
0
,
arg
.
Batch
*
arg_
.
M
*
arg_
.
N
*
sizeof
(
CDataType
),
stream_config
.
stream_id_
));
};
ave_time
=
ck
::
utility
::
launch_and_time_kernel_with_preprocess
<
false
>
(
...
...
@@ -889,7 +898,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
index_t
BatchStrideE
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
)
CElementwiseOperation
c_element_op
,
index_t
KBatch
=
1
)
{
return
Argument
{
static_cast
<
const
ADataType
*>
(
p_a
),
static_cast
<
const
BDataType
*>
(
p_b
),
...
...
@@ -909,7 +919,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
Batch
,
a_element_op
,
b_element_op
,
c_element_op
};
c_element_op
,
KBatch
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
...
...
@@ -934,7 +945,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
index_t
BatchStrideE
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
)
override
CElementwiseOperation
c_element_op
,
index_t
KBatch
=
1
)
override
{
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_a
),
static_cast
<
const
BDataType
*>
(
p_b
),
...
...
@@ -954,7 +966,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
Batch
,
a_element_op
,
b_element_op
,
c_element_op
);
c_element_op
,
KBatch
);
}
// polymorphic
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp
View file @
d83e2d25
...
...
@@ -41,7 +41,7 @@ __global__ void
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
auto
splitk_batch_offset
=
typename
GridwiseGemm
::
SplitKBatchOffset
(
karg
);
auto
splitk_batch_offset
=
typename
GridwiseGemm
::
SplitKBatchOffset
(
karg
,
blockIdx
.
z
);
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
,
CGlobalMemoryDataOperation
,
TailNum
>(
karg
.
p_a_grid
+
splitk_batch_offset
.
a_k_split_offset
,
...
...
@@ -76,7 +76,7 @@ __global__ void
__shared__
char
p_shared_0
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
__shared__
char
p_shared_1
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
auto
splitk_batch_offset
=
typename
GridwiseGemm
::
SplitKBatchOffset
(
karg
);
auto
splitk_batch_offset
=
typename
GridwiseGemm
::
SplitKBatchOffset
(
karg
,
blockIdx
.
z
);
GridwiseGemm
::
template
Run_2Lds
<
HasMainKBlockLoop
,
CGlobalMemoryDataOperation
,
TailNum
>(
karg
.
p_a_grid
+
splitk_batch_offset
.
a_k_split_offset
,
...
...
@@ -639,27 +639,27 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
struct
SplitKBatchOffset
{
__device__
SplitKBatchOffset
(
Argument
&
karg
)
__device__
SplitKBatchOffset
(
Argument
&
karg
,
index_t
k_id
)
{
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>
)
{
a_k_split_offset
=
blockIdx
.
z
*
karg
.
KRead
;
a_k_split_offset
=
k_id
*
karg
.
KRead
;
}
else
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
ColumnMajor
,
ALayout
>
)
{
a_k_split_offset
=
blockIdx
.
z
*
karg
.
KRead
*
karg
.
StrideA
;
a_k_split_offset
=
k_id
*
karg
.
KRead
*
karg
.
StrideA
;
}
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>
)
{
b_k_split_offset
=
blockIdx
.
z
*
karg
.
KRead
*
karg
.
StrideB
;
b_k_split_offset
=
k_id
*
karg
.
KRead
*
karg
.
StrideB
;
}
else
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
ColumnMajor
,
BLayout
>
)
{
b_k_split_offset
=
blockIdx
.
z
*
karg
.
KRead
;
b_k_split_offset
=
k_id
*
karg
.
KRead
;
}
if
(
blockIdx
.
z
<
static_cast
<
uint32_t
>
(
karg
.
KBatch
-
1
)
)
if
(
k_id
<
karg
.
KBatch
-
1
)
{
karg
.
K
=
karg
.
KRead
;
}
...
...
include/ck_tile/ops/smoothquant/kernel/smoothquant_kernel.hpp
View file @
d83e2d25
...
...
@@ -19,7 +19,8 @@ struct SmoothquantHostArgs
index_t
m
;
index_t
n
;
index_t
stride
;
// row_stride
index_t
x_stride
;
// input row_stride
index_t
y_stride
;
// output row_stride
};
// TODO: Extract some type to wrapper class
...
...
@@ -58,14 +59,21 @@ struct Smoothquant
index_t
m
;
index_t
n
;
index_t
stride
;
// row_stride
index_t
x_stride
;
// input row_stride
index_t
y_stride
;
// out row_stride
};
using
Hargs
=
SmoothquantHostArgs
;
CK_TILE_HOST
static
constexpr
Kargs
MakeKargs
(
const
Hargs
&
hargs
)
{
return
Kargs
{
hargs
.
p_x
,
hargs
.
p_xscale
,
hargs
.
p_yscale
,
hargs
.
p_qy
,
hargs
.
m
,
hargs
.
n
,
hargs
.
stride
};
return
Kargs
{
hargs
.
p_x
,
hargs
.
p_xscale
,
hargs
.
p_yscale
,
hargs
.
p_qy
,
hargs
.
m
,
hargs
.
n
,
hargs
.
x_stride
,
hargs
.
y_stride
};
}
CK_TILE_HOST
static
constexpr
auto
GridSize
(
const
Hargs
&
hargs
)
...
...
@@ -116,7 +124,7 @@ struct Smoothquant
const
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
static_cast
<
const
XDataType
*>
(
kargs
.
p_x
),
make_tuple
(
kargs
.
m
,
kargs
.
n
),
make_tuple
(
kargs
.
stride
,
1
),
make_tuple
(
kargs
.
x_
stride
,
1
),
number
<
Vector_N
>
{},
number
<
1
>
{});
...
...
@@ -157,7 +165,7 @@ struct Smoothquant
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
static_cast
<
QYDataType
*>
(
kargs
.
p_qy
),
make_tuple
(
kargs
.
m
,
kargs
.
n
),
make_tuple
(
kargs
.
stride
,
1
),
make_tuple
(
kargs
.
y_
stride
,
1
),
number
<
Vector_N
>
{},
number
<
1
>
{});
...
...
library/src/tensor_operation_instance/gpu/gemm_universal_batched/device_batched_gemm_xdl_universal_bf16_bf16_bf16/device_batched_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn.hpp
View file @
d83e2d25
...
...
@@ -52,6 +52,9 @@ using device_batched_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_instances =
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
DsLayout
,
Row
,
BF16
,
BF16
,
DsDataType
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
256
,
256
,
32
,
8
,
8
,
32
,
32
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
S
<
4
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v5
>
,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
DsLayout
,
Row
,
BF16
,
BF16
,
DsDataType
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
224
,
256
,
64
,
8
,
8
,
16
,
16
,
7
,
8
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
2
,
S
<
1
,
16
,
1
,
16
>
,
S
<
4
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
>
,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
DsLayout
,
Row
,
BF16
,
BF16
,
DsDataType
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
256
,
224
,
64
,
8
,
8
,
16
,
16
,
8
,
7
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
2
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
4
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
>
,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
DsLayout
,
Row
,
BF16
,
BF16
,
DsDataType
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
256
,
160
,
64
,
8
,
8
,
16
,
16
,
8
,
5
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
2
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
4
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
>
,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
DsLayout
,
Row
,
BF16
,
BF16
,
DsDataType
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
128
,
160
,
64
,
8
,
8
,
32
,
32
,
1
,
5
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
64
,
1
,
4
>
,
S
<
8
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
>
,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
DsLayout
,
Row
,
BF16
,
BF16
,
DsDataType
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
160
,
128
,
64
,
8
,
8
,
32
,
32
,
5
,
1
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
S
<
4
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
>
,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
DsLayout
,
Row
,
BF16
,
BF16
,
DsDataType
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
128
,
128
,
64
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
S
<
4
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
>
,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
DsLayout
,
Row
,
BF16
,
BF16
,
DsDataType
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
128
,
128
,
64
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
S
<
4
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v5
>
,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
DsLayout
,
Row
,
BF16
,
BF16
,
DsDataType
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
128
,
128
,
64
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
S
<
4
>
,
BlockGemmPipelineScheduler
::
Interwave
,
BlockGemmPipelineVersion
::
v1
>
...
...
library/src/tensor_operation_instance/gpu/gemm_universal_batched/device_batched_gemm_xdl_universal_f8_f8_bf16/device_batched_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn.hpp
View file @
d83e2d25
...
...
@@ -42,6 +42,7 @@ using device_batched_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_instances = std
//##################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
#ifdef __gfx94__
// Compute friendly
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
DsLayout
,
Row
,
F8
,
F8
,
DsDataType
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
256
,
256
,
64
,
16
,
16
,
32
,
32
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
,
F8
>
,
...
...
@@ -72,6 +73,7 @@ using device_batched_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_instances = std:
//##################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
#if defined(__gfx94__) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH)
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
DsLayout
,
Row
,
F8
,
F8
,
DsDataType
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
32
,
16
,
128
,
16
,
16
,
16
,
16
,
1
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
S
<
2
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
,
F8
>
,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
DsLayout
,
Row
,
F8
,
F8
,
DsDataType
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
64
,
16
,
16
,
128
,
16
,
16
,
16
,
16
,
1
,
1
,
S
<
8
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
S
<
4
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
,
F8
>
,
...
...
profiler/include/profiler/profile_gemm_universal_batched_impl.hpp
View file @
d83e2d25
...
...
@@ -48,6 +48,7 @@ bool profile_gemm_universal_batched_impl(int do_verification,
int
StrideB
,
int
StrideC
,
int
BatchCount
,
int
KBatch
,
int
n_warmup
,
int
n_iter
,
uint64_t
rotating
=
0
)
...
...
@@ -147,89 +148,100 @@ bool profile_gemm_universal_batched_impl(int do_verification,
float
best_ave_time
=
0
;
float
best_tflops
=
0
;
float
best_gb_per_sec
=
0
;
float
best_kbatch
=
0
;
// profile device op instances
for
(
auto
&
op_ptr
:
op_ptrs
)
{
std
::
unique_ptr
<
tensor_operation
::
device
::
BaseArgument
>
argument_ptr
;
// false branch for multi d dl kernel
argument_ptr
=
op_ptr
->
MakeArgumentPointer
(
static_cast
<
ADataType
*>
(
a_device_buf
.
GetDeviceBuffer
()),
static_cast
<
BDataType
*>
(
b_device_buf
.
GetDeviceBuffer
()),
{},
static_cast
<
CDataType
*>
(
c_device_buf
.
GetDeviceBuffer
()),
M
,
N
,
K
,
BatchCount
,
StrideA
,
StrideB
,
{},
StrideC
,
BatchStrideA
,
BatchStrideB
,
{},
BatchStrideC
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
{},
ck
::
tensor_operation
::
element_wise
::
PassThrough
{},
ck
::
tensor_operation
::
element_wise
::
PassThrough
{});
auto
invoker_ptr
=
op_ptr
->
MakeInvokerPointer
();
if
(
op_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
// re-init C to zero before profiling next kernel
c_device_buf
.
SetZero
();
std
::
string
op_name
=
op_ptr
->
GetTypeString
();
std
::
vector
<
int
>
kbatch_list
=
{
1
,
2
,
4
,
8
,
16
,
19
,
32
,
38
};
float
ave_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
time_kernel
,
0
,
n_warmup
,
n_iter
,
true
,
rotating_count
});
if
(
KBatch
>
0
)
{
kbatch_list
=
{
KBatch
};
}
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
BatchCount
*
M
*
N
*
K
;
for
(
std
::
size_t
i
=
0
;
i
<
kbatch_list
.
size
();
i
++
)
{
auto
kbatch_curr
=
kbatch_list
[
i
];
auto
argument_ptr
=
op_ptr
->
MakeArgumentPointer
(
static_cast
<
ADataType
*>
(
a_device_buf
.
GetDeviceBuffer
()),
static_cast
<
BDataType
*>
(
b_device_buf
.
GetDeviceBuffer
()),
{},
static_cast
<
CDataType
*>
(
c_device_buf
.
GetDeviceBuffer
()),
M
,
N
,
K
,
BatchCount
,
StrideA
,
StrideB
,
{},
StrideC
,
BatchStrideA
,
BatchStrideB
,
{},
BatchStrideC
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
{},
ck
::
tensor_operation
::
element_wise
::
PassThrough
{},
ck
::
tensor_operation
::
element_wise
::
PassThrough
{},
kbatch_curr
);
auto
invoker_ptr
=
op_ptr
->
MakeInvokerPointer
();
if
(
op_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
std
::
string
op_name
=
op_ptr
->
GetTypeString
();
std
::
size_t
num_btype
=
(
sizeof
(
ADataType
)
*
M
*
K
+
sizeof
(
BDataType
)
*
K
*
N
+
sizeof
(
CDataType
)
*
M
*
N
)
*
BatchC
ount
;
float
ave_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
time_kernel
,
0
,
n_warmup
,
n_iter
,
true
,
rotating_c
ount
})
;
float
t
flop
s
=
st
atic_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
std
::
size_t
flop
=
st
d
::
size_t
(
2
)
*
BatchCount
*
M
*
N
*
K
;
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
std
::
size_t
num_btype
=
(
sizeof
(
ADataType
)
*
M
*
K
+
sizeof
(
BDataType
)
*
K
*
N
+
sizeof
(
CDataType
)
*
M
*
N
)
*
BatchCount
;
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
op_name
<<
std
::
endl
;
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
if
(
tflops
>
best_tflops
)
{
best_op_name
=
op_name
;
best_tflops
=
tflops
;
best_ave_time
=
ave_time
;
best_gb_per_sec
=
gb_per_sec
;
}
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
if
(
do_verification
)
{
c_device_buf
.
FromDevice
(
c_g_m_n_device_result
.
mData
.
data
());
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
op_name
<<
", KBatch "
<<
kbatch_curr
<<
std
::
endl
;
pass
=
pass
&
ck
::
utils
::
check_err
(
c_g_m_n_device_result
,
c_g_m_n_host_result
);
if
(
tflops
>
best_tflops
)
{
best_op_name
=
op_name
;
best_tflops
=
tflops
;
best_ave_time
=
ave_time
;
best_gb_per_sec
=
gb_per_sec
;
best_kbatch
=
kbatch_curr
;
}
if
(
do_
log
)
if
(
do_
verification
)
{
LogRangeAsType
<
float
>
(
std
::
cout
<<
"a : "
,
a_g_m_k
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"b: "
,
b_g_k_n
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"c_host: "
,
c_g_m_n_host_result
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"c_device: "
,
c_g_m_n_device_result
.
mData
,
","
)
<<
std
::
endl
;
c_device_buf
.
FromDevice
(
c_g_m_n_device_result
.
mData
.
data
());
pass
=
pass
&
ck
::
utils
::
check_err
(
c_g_m_n_device_result
,
c_g_m_n_host_result
);
if
(
do_log
)
{
LogRangeAsType
<
float
>
(
std
::
cout
<<
"a : "
,
a_g_m_k
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"b: "
,
b_g_k_n
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"c_host: "
,
c_g_m_n_host_result
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"c_device: "
,
c_g_m_n_device_result
.
mData
,
","
)
<<
std
::
endl
;
}
}
}
}
else
{
std
::
cout
<<
op_ptr
->
GetTypeString
()
<<
" does not support this problem"
<<
std
::
endl
;
else
{
std
::
cout
<<
op_ptr
->
GetTypeString
()
<<
" does not support this problem"
<<
std
::
endl
;
}
}
}
...
...
@@ -270,8 +282,8 @@ bool profile_gemm_universal_batched_impl(int do_verification,
std
::
cout
<<
" B = "
<<
BatchCount
<<
" M = "
<<
M
<<
" N = "
<<
N
<<
" K = "
<<
K
<<
" StrideA = "
<<
StrideA
<<
" StrideB = "
<<
StrideB
<<
" StrideC = "
<<
StrideC
<<
": "
<<
best_ave_time
<<
" ms, "
<<
best_tflops
<<
" TFlops, "
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_op_name
<<
std
::
endl
;
<<
" KBatch = "
<<
best_kbatch
<<
": "
<<
best_ave_time
<<
" ms, "
<<
best_tflops
<<
"
TFlops, "
<<
best_gb_per_sec
<<
"
GB/s, "
<<
best_op_name
<<
std
::
endl
;
return
pass
;
}
...
...
profiler/src/profile_gemm_universal_batched.cpp
View file @
d83e2d25
...
...
@@ -31,7 +31,7 @@ enum struct GemmDataType
int
profile_batched_gemm_universal
(
int
argc
,
char
*
argv
[])
{
if
(
argc
!=
1
8
&&
argc
!=
2
1
)
if
(
argc
!=
1
9
&&
argc
!=
2
2
)
{
// clang-format off
printf
(
"arg1: tensor operation ("
OP_NAME
": "
OP_DESC
")
\n
"
);
...
...
@@ -44,11 +44,11 @@ int profile_batched_gemm_universal(int argc, char* argv[])
printf
(
"arg5: initialization (0: no init; 1: integer value; 2: decimal value)
\n
"
);
printf
(
"arg6: print tensor value (0: no; 1: yes)
\n
"
);
printf
(
"arg7: time kernel (0=n0, 1=yes)
\n
"
);
printf
(
"arg8 to 1
7
: M, N, K, StrideA, StrideB, StrideC, BatchStrideA, BatchStrideB, BatchStrideC, BatchCount
\n
"
);
printf
(
"arg8 to 1
8
: M, N, K, StrideA, StrideB, StrideC, BatchStrideA, BatchStrideB, BatchStrideC, BatchCount
, KBatch
\n
"
);
printf
(
"optional:
\n
"
);
printf
(
"arg1
8
: number of warm-up cycles (default 1)
\n
"
);
printf
(
"arg
19
: number of iterations (default 10)
\n
"
);
printf
(
"arg2
0
: memory for rotating buffer (default 0, size in MB)
\n
"
);
printf
(
"arg1
9
: number of warm-up cycles (default 1)
\n
"
);
printf
(
"arg
20
: number of iterations (default 10)
\n
"
);
printf
(
"arg2
1
: memory for rotating buffer (default 0, size in MB)
\n
"
);
// clang-format on
exit
(
1
);
}
...
...
@@ -56,11 +56,11 @@ int profile_batched_gemm_universal(int argc, char* argv[])
int
n_warmup
=
1
;
int
n_iter
=
10
;
uint64_t
rotating
=
0
;
if
(
argc
==
2
1
)
if
(
argc
==
2
2
)
{
n_warmup
=
std
::
stoi
(
argv
[
1
8
]);
n_iter
=
std
::
stoi
(
argv
[
19
]);
rotating
=
std
::
stoull
(
argv
[
2
0
])
*
1024
*
1024
;
n_warmup
=
std
::
stoi
(
argv
[
1
9
]);
n_iter
=
std
::
stoi
(
argv
[
20
]);
rotating
=
std
::
stoull
(
argv
[
2
1
])
*
1024
*
1024
;
}
const
auto
data_type
=
static_cast
<
GemmDataType
>
(
std
::
stoi
(
argv
[
2
]));
...
...
@@ -83,6 +83,7 @@ int profile_batched_gemm_universal(int argc, char* argv[])
const
int
BatchStrideC
=
std
::
stoi
(
argv
[
16
]);
const
int
BatchCount
=
std
::
stoi
(
argv
[
17
]);
const
int
KBatch
=
std
::
stoi
(
argv
[
18
]);
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)
using
F8
=
ck
::
f8_t
;
...
...
@@ -159,6 +160,7 @@ int profile_batched_gemm_universal(int argc, char* argv[])
StrideB_
,
StrideC_
,
BatchCount
,
KBatch
,
n_warmup
,
n_iter
,
rotating
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment