Project: gaoqiong/composable_kernel

Commit e4e99a49, authored Sep 22, 2022 by Po-Yen, Chen
Parent: 7acbf104

    Use new utilities to shorten codes
Changes: 144. Showing 20 changed files with 690 additions and 870 deletions (+690, -870).
Files shown:

  example/24_batched_gemm/common.hpp  (+35, -0)
  example/24_batched_gemm/run_batched_gemm_example.inc  (+20, -21)
  example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp  (+56, -69)
  example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m3n2k1_xdl_fp16.cpp  (+52, -66)
  example/26_contraction/contraction_bilinear_xdl_fp32.cpp  (+46, -60)
  example/26_contraction/contraction_scale_xdl_fp32.cpp  (+42, -54)
  example/27_layernorm/layernorm_blockwise.cpp  (+34, -31)
  example/28_grouped_gemm_bias_e_permute/grouped_gemm_bias_e_permute_xdl_fp16.cpp  (+53, -68)
  example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_xdl_fp16.cpp  (+57, -71)
  example/30_grouped_convnd_fwd_bias_relu_add/grouped_convnd_fwd_bias_relu_add_common.hpp  (+53, -76)
  example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_bf16.cpp  (+1, -26)
  example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_fp16.cpp  (+1, -26)
  example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_fp32.cpp  (+1, -25)
  example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_int4.cpp  (+1, -23)
  example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_int8.cpp  (+1, -23)
  example/31_batched_gemm_gemm/common.hpp  (+33, -0)
  example/31_batched_gemm_gemm/run_batched_gemm_gemm_example.inc  (+50, -60)
  example/32_batched_gemm_scale_softmax_gemm/batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp  (+53, -61)
  example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp  (+53, -61)
  example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_xdl_fp16.cpp  (+48, -49)
example/24_batched_gemm/common.hpp (new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <initializer_list>
#include <iostream>
#include <numeric>
#include <random>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d_xdl.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using BF16 = ck::bhalf_t;
using F16  = ck::half_t;
using F32  = float;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;
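Note: common.hpp pulls in "ck/library/utility/literals.hpp", and the reworked run_batched_gemm_example.inc below writes host strides such as 1_uz under using namespace ck::literals. The literal's definition is not part of this diff; a minimal sketch of what a _uz-style std::size_t literal could look like, under that assumption (hypothetical namespace, illustration only):

    // Assumed sketch of a _uz-style user-defined literal (not the actual
    // contents of ck/library/utility/literals.hpp). It converts an integer
    // literal to std::size_t so a braced list like {batch_stride, stride, 1_uz}
    // stays homogeneous in std::size_t.
    #include <cstddef>

    namespace ck_literals_sketch {
    constexpr std::size_t operator""_uz(unsigned long long value)
    {
        return static_cast<std::size_t>(value);
    }
    } // namespace ck_literals_sketch

    // Usage mirroring the diff below:
    //   using namespace ck_literals_sketch;
    //   std::size_t one = 1_uz;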
example/24_batched_gemm/run_batched_gemm_example.inc

+#include <random>
 // SPDX-License-Identifier: MIT
 // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once

@@ -28,8 +29,6 @@ struct ExecutionConfig final
 bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
 {
-    using namespace ck::literals;
 #if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4)
     static_assert(sizeof(ck::int4_t) == sizeof(int8_t));
     static_assert(sizeof(ADataType) == sizeof(KernelADataType));

@@ -48,6 +47,8 @@ bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
           batch_stride_C,
           batch_count] = problem_size;
+    using namespace ck::literals;
     // GEMM shape
     auto f_host_tensor_descriptor = [](std::size_t batch_count_,
                                        std::size_t row,

@@ -55,15 +56,13 @@ bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
                                        std::size_t stride,
                                        std::size_t batch_stride,
                                        auto layout) {
-        if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
+        if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
         {
-            return HostTensorDescriptor(std::vector<std::size_t>({batch_count_, row, col}),
-                                        std::vector<std::size_t>({batch_stride, stride, 1}));
+            return HostTensorDescriptor({batch_count_, row, col}, {batch_stride, stride, 1_uz});
         }
         else
         {
-            return HostTensorDescriptor(std::vector<std::size_t>({batch_count_, row, col}),
-                                        std::vector<std::size_t>({batch_stride, 1, stride}));
+            return HostTensorDescriptor({batch_count_, row, col}, {batch_stride, 1_uz, stride});
         }
     };

@@ -79,9 +78,9 @@ bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
         f_host_tensor_descriptor(batch_count, M, N, stride_C, batch_stride_C, ELayout{}));
 #endif
-    std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
-    std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl;
-    std::cout << "e_g_m_n: " << e_g_m_n_device_result.mDesc << std::endl;
+    std::cout << "a_g_m_k: " << a_g_m_k.GetDesc() << std::endl;
+    std::cout << "b_g_k_n: " << b_g_k_n.GetDesc() << std::endl;
+    std::cout << "e_g_m_n: " << e_g_m_n_device_result.GetDesc() << std::endl;
     switch(config.init_method)
     {

@@ -96,19 +95,19 @@ bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
         break;
     }
-    DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize());
-    DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpaceSize());
-    DeviceMem c_device_buf(sizeof(EDataType) * e_g_m_n_device_result.mDesc.GetElementSpaceSize());
+    DeviceMem a_device_buf(a_g_m_k.GetMemorySize());
+    DeviceMem b_device_buf(b_g_k_n.GetMemorySize());
+    DeviceMem c_device_buf(e_g_m_n_device_result.GetMemorySize());
 #ifdef BUILD_INT4_EXAMPLE
     const Tensor<KernelADataType> a_g_m_k_converted(a_g_m_k);
     const Tensor<KernelBDataType> b_g_k_n_converted(b_g_k_n);
-    a_device_buf.ToDevice(a_g_m_k_converted.mData.data());
-    b_device_buf.ToDevice(b_g_k_n_converted.mData.data());
+    a_device_buf.ToDevice(a_g_m_k_converted.data());
+    b_device_buf.ToDevice(b_g_k_n_converted.data());
 #else
-    a_device_buf.ToDevice(a_g_m_k.mData.data());
-    b_device_buf.ToDevice(b_g_k_n.mData.data());
+    a_device_buf.ToDevice(a_g_m_k.data());
+    b_device_buf.ToDevice(b_g_k_n.data());
 #endif
     auto a_element_op = AElementOp{};
     auto b_element_op = BElementOp{};

@@ -150,7 +149,7 @@ bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
     if(config.do_verification)
     {
-        c_device_buf.FromDevice(e_g_m_n_device_result.mData.data());
+        c_device_buf.FromDevice(e_g_m_n_device_result.data());
         using ReferenceBatchedGemmInstance =
             ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,

@@ -174,11 +173,11 @@ bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
 #ifdef BUILD_INT4_EXAMPLE
         const Tensor<EDataType> e_device_result_converted(e_g_m_n_device_result);
-        pass &= ck::utils::check_err(e_device_result_converted.mData, e_g_m_n_host_result.mData);
+        pass &= ck::utils::check_err(e_device_result_converted, e_g_m_n_host_result);
 #else
-        pass = ck::utils::check_err(e_g_m_n_device_result.mData,
-                                    e_g_m_n_host_result.mData,
-                                    "Error: Incorrect results c");
+        pass = ck::utils::check_err(
+            e_g_m_n_device_result, e_g_m_n_host_result, "Error: Incorrect results c");
 #endif
     }
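The edits above consistently replace direct member access (a_g_m_k.mDesc, a_g_m_k.mData.data(), sizeof(ADataType) * mDesc.GetElementSpaceSize()) with accessor helpers (GetDesc(), data(), GetMemorySize()) on the host Tensor. Those helpers are not shown in this commit; a self-contained sketch of how such forwarding accessors could look, using hypothetical stand-in types (MiniDescriptor/MiniTensor are illustration names only, not the library's host_tensor.hpp):

    // Assumed sketch of the accessor-style helpers relied on by the new code.
    #include <cstddef>
    #include <vector>

    struct MiniDescriptor // stand-in for HostTensorDescriptor in this sketch
    {
        std::vector<std::size_t> lengths;
        const std::vector<std::size_t>& GetLengths() const { return lengths; }
        std::size_t GetElementSpaceSize() const
        {
            std::size_t n = 1;
            for(auto l : lengths) n *= l;
            return n;
        }
    };

    template <typename T>
    struct MiniTensor // stand-in for ck's host Tensor in this sketch
    {
        MiniDescriptor mDesc;
        std::vector<T> mData;

        const MiniDescriptor& GetDesc() const { return mDesc; }
        const std::vector<std::size_t>& GetLengths() const { return mDesc.GetLengths(); }
        T* data() { return mData.data(); }
        const T* data() const { return mData.data(); }
        // Bytes needed to back this tensor in device memory:
        std::size_t GetMemorySize() const { return sizeof(T) * mDesc.GetElementSpaceSize(); }
    };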
example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp

 // SPDX-License-Identifier: MIT
 // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+#include <cstdlib>
+#include <initializer_list>
 #include <iostream>
 #include <numeric>
-#include <initializer_list>
-#include <cstdlib>
 #include "ck/ck.hpp"
-#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
+#include "ck/library/utility/array.hpp"
 #include "ck/library/utility/check_err.hpp"
 #include "ck/library/utility/device_memory.hpp"
 #include "ck/library/utility/host_tensor.hpp"
 #include "ck/library/utility/host_tensor_generator.hpp"
-#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
+#include "ck/library/utility/numeric.hpp"

 template <ck::index_t... Is>
 using S = ck::Sequence<Is...>;

@@ -110,7 +111,7 @@ struct ReferenceContraction_G1_M2_N3_K1 : public ck::tensor_operation::device::B
     float Run(const Argument& arg)
     {
         auto f_gs_ms_ns = [&](auto g0, auto m0, auto m1, auto n0, auto n1, auto n2) {
-            const int K0 = arg.a_gs_ms_ks_.mDesc.GetLengths()[3];
+            const int K0 = arg.a_gs_ms_ks_.GetLengths()[3];
             AccDataType v_acc = 0;

@@ -136,12 +137,12 @@ struct ReferenceContraction_G1_M2_N3_K1 : public ck::tensor_operation::device::B
         };
         make_ParallelTensorFunctor(f_gs_ms_ns,
-                                   arg.e_gs_ms_ns_.mDesc.GetLengths()[0],
-                                   arg.e_gs_ms_ns_.mDesc.GetLengths()[1],
-                                   arg.e_gs_ms_ns_.mDesc.GetLengths()[2],
-                                   arg.e_gs_ms_ns_.mDesc.GetLengths()[3],
-                                   arg.e_gs_ms_ns_.mDesc.GetLengths()[4],
-                                   arg.e_gs_ms_ns_.mDesc.GetLengths()[5])(
+                                   arg.e_gs_ms_ns_.GetLengths()[0],
+                                   arg.e_gs_ms_ns_.GetLengths()[1],
+                                   arg.e_gs_ms_ns_.GetLengths()[2],
+                                   arg.e_gs_ms_ns_.GetLengths()[3],
+                                   arg.e_gs_ms_ns_.GetLengths()[4],
+                                   arg.e_gs_ms_ns_.GetLengths()[5])(
             std::thread::hardware_concurrency());
         return 0;

@@ -246,26 +247,16 @@ int main(int argc, char* argv[])
         exit(0);
     }
-    Tensor<ADataType> a_gs_ms_ks(
-        std::vector<std::size_t>(a_gs_ms_ks_lengths.begin(), a_gs_ms_ks_lengths.end()),
-        std::vector<std::size_t>(a_gs_ms_ks_strides.begin(), a_gs_ms_ks_strides.end()));
-    Tensor<BDataType> b_gs_ns_ks(
-        std::vector<std::size_t>(b_gs_ns_ks_lengths.begin(), b_gs_ns_ks_lengths.end()),
-        std::vector<std::size_t>(b_gs_ns_ks_strides.begin(), b_gs_ns_ks_strides.end()));
-    Tensor<DDataType> d_gs_ms_ns(
-        std::vector<std::size_t>(d_gs_ms_ns_lengths.begin(), d_gs_ms_ns_lengths.end()),
-        std::vector<std::size_t>(d_gs_ms_ns_strides.begin(), d_gs_ms_ns_strides.end()));
-    Tensor<EDataType> e_gs_ms_ns_host_result(
-        std::vector<std::size_t>(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()),
-        std::vector<std::size_t>(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()));
-    Tensor<EDataType> e_gs_ms_ns_device_result(
-        std::vector<std::size_t>(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()),
-        std::vector<std::size_t>(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()));
-    std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl;
-    std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.mDesc << std::endl;
-    std::cout << "d_gs_ms_ns: " << d_gs_ms_ns.mDesc << std::endl;
-    std::cout << "e_gs_ms_ns: " << e_gs_ms_ns_host_result.mDesc << std::endl;
+    Tensor<ADataType> a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
+    Tensor<BDataType> b_gs_ns_ks(b_gs_ns_ks_lengths, b_gs_ns_ks_strides);
+    Tensor<DDataType> d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
+    Tensor<EDataType> e_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides);
+    Tensor<EDataType> e_gs_ms_ns_device_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides);
+    std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.GetDesc() << std::endl;
+    std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.GetDesc() << std::endl;
+    std::cout << "d_gs_ms_ns: " << d_gs_ms_ns.GetDesc() << std::endl;
+    std::cout << "e_gs_ms_ns: " << e_gs_ms_ns_host_result.GetDesc() << std::endl;
     switch(init_method)
     {

@@ -282,15 +273,14 @@ int main(int argc, char* argv[])
         break;
     }
-    DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize());
-    DeviceMem b_device_buf(sizeof(BDataType) * b_gs_ns_ks.mDesc.GetElementSpaceSize());
-    DeviceMem d_device_buf(sizeof(DDataType) * d_gs_ms_ns.mDesc.GetElementSpaceSize());
-    DeviceMem e_device_buf(sizeof(EDataType) * e_gs_ms_ns_device_result.mDesc.GetElementSpaceSize());
+    DeviceMem a_device_buf(a_gs_ms_ks.GetMemorySize());
+    DeviceMem b_device_buf(b_gs_ns_ks.GetMemorySize());
+    DeviceMem d_device_buf(d_gs_ms_ns.GetMemorySize());
+    DeviceMem e_device_buf(e_gs_ms_ns_device_result.GetMemorySize());
-    a_device_buf.ToDevice(a_gs_ms_ks.mData.data());
-    b_device_buf.ToDevice(b_gs_ns_ks.mData.data());
-    d_device_buf.ToDevice(d_gs_ms_ns.mData.data());
+    a_device_buf.ToDevice(a_gs_ms_ks.data());
+    b_device_buf.ToDevice(b_gs_ns_ks.data());
+    d_device_buf.ToDevice(d_gs_ms_ns.data());
     // set zero
     e_device_buf.SetZero();

@@ -299,19 +289,21 @@ int main(int argc, char* argv[])
     auto b_element_op   = BElementOp{};
    auto cde_element_op = CDEElementOp{};
+    using ck::utils::to_array;
     // device operation
     auto op       = DeviceOpInstance{};
     auto invoker  = op.MakeInvoker();
     auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(),
                                     b_device_buf.GetDeviceBuffer(),
-                                    std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()},
+                                    to_array({d_device_buf.GetDeviceBuffer()}),
                                     e_device_buf.GetDeviceBuffer(),
                                     a_gs_ms_ks_lengths,
                                     a_gs_ms_ks_strides,
                                     b_gs_ns_ks_lengths,
                                     b_gs_ns_ks_strides,
-                                    std::array<std::vector<ck::index_t>, 1>{d_gs_ms_ns_lengths},
-                                    std::array<std::vector<ck::index_t>, 1>{d_gs_ms_ns_strides},
+                                    to_array({d_gs_ms_ns_lengths}),
+                                    to_array({d_gs_ms_ns_strides}),
                                     e_gs_ms_ns_lengths,
                                     e_gs_ms_ns_strides,
                                     a_element_op,

@@ -327,20 +319,20 @@ int main(int argc, char* argv[])
     float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
-    std::size_t M = std::accumulate(e_gs_ms_ns_lengths.begin() + NumDimG,
-                                    e_gs_ms_ns_lengths.begin() + NumDimG + NumDimM,
-                                    ck::index_t{1},
-                                    std::multiplies<ck::index_t>{});
+    std::size_t M = ck::accumulate_n(e_gs_ms_ns_lengths.begin() + NumDimG,
+                                     NumDimM, ck::index_t{1}, std::multiplies<ck::index_t>{});
-    std::size_t N = std::accumulate(e_gs_ms_ns_lengths.begin() + NumDimG + NumDimM,
-                                    e_gs_ms_ns_lengths.begin() + NumDimG + NumDimM + NumDimN,
-                                    ck::index_t{1},
-                                    std::multiplies<ck::index_t>{});
+    std::size_t N = ck::accumulate_n(e_gs_ms_ns_lengths.begin() + NumDimG + NumDimM,
+                                     NumDimN, ck::index_t{1}, std::multiplies<ck::index_t>{});
-    std::size_t K = std::accumulate(a_gs_ms_ks_lengths.begin() + NumDimG + NumDimM,
-                                    a_gs_ms_ks_lengths.begin() + NumDimG + NumDimM + NumDimK,
-                                    ck::index_t{1},
-                                    std::multiplies<ck::index_t>{});
+    std::size_t K = ck::accumulate_n(a_gs_ms_ks_lengths.begin() + NumDimG + NumDimM,
+                                     NumDimK, ck::index_t{1}, std::multiplies<ck::index_t>{});
     std::size_t flop = std::size_t(2) * M * N * K;
     std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +

@@ -353,13 +345,11 @@ int main(int argc, char* argv[])
     std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
               << " GB/s, " << op.GetTypeString() << std::endl;
-    e_device_buf.FromDevice(e_gs_ms_ns_device_result.mData.data());
+    e_device_buf.FromDevice(e_gs_ms_ns_device_result.data());
     if(do_verification)
     {
-        Tensor<CShuffleDataType> c_gs_ms_ns_host_result(
-            std::vector<std::size_t>(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()),
-            std::vector<std::size_t>(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()));
+        Tensor<CShuffleDataType> c_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides);
         using ReferenceOpInstance = ReferenceContraction_G1_M2_N3_K1<NumDimM,
                                                                      NumDimN,

@@ -384,18 +374,17 @@ int main(int argc, char* argv[])
         ref_invoker.Run(ref_argument);
-        for(size_t g0 = 0; g0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[0]; ++g0)
+        for(size_t g0 = 0; g0 < e_gs_ms_ns_host_result.GetLengths()[0]; ++g0)
         {
-        for(size_t m0 = 0; m0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[1]; ++m0)
+        for(size_t m0 = 0; m0 < e_gs_ms_ns_host_result.GetLengths()[1]; ++m0)
         {
-        for(size_t m1 = 0; m1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[2]; ++m1)
+        for(size_t m1 = 0; m1 < e_gs_ms_ns_host_result.GetLengths()[2]; ++m1)
         {
-        for(size_t n0 = 0; n0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[3]; ++n0)
+        for(size_t n0 = 0; n0 < e_gs_ms_ns_host_result.GetLengths()[3]; ++n0)
         {
-        for(size_t n1 = 0; n1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[4]; ++n1)
+        for(size_t n1 = 0; n1 < e_gs_ms_ns_host_result.GetLengths()[4]; ++n1)
        {
-        for(size_t n2 = 0; n2 < e_gs_ms_ns_host_result.mDesc.GetLengths()[5]; ++n2)
+        for(size_t n2 = 0; n2 < e_gs_ms_ns_host_result.GetLengths()[5]; ++n2)
         {
             cde_element_op(e_gs_ms_ns_host_result(g0, m0, m1, n0, n1, n2),
                            c_gs_ms_ns_host_result(g0, m0, m1, n0, n1, n2),

@@ -407,9 +396,7 @@ int main(int argc, char* argv[])
         }
     }
-    return ck::utils::check_err(e_gs_ms_ns_device_result.mData, e_gs_ms_ns_host_result.mData)
-               ? 0
-               : 1;
+    return ck::utils::check_err(e_gs_ms_ns_device_result, e_gs_ms_ns_host_result) ? 0 : 1;
     }
     return 0;
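The M/N/K reductions above switch from std::accumulate over an explicit (first, first + n) iterator pair to ck::accumulate_n, which takes the element count directly. Its definition lives in ck/library/utility/numeric.hpp and is not shown in this diff; a plausible sketch, assuming it simply forwards to std::accumulate (hypothetical namespace, illustration only):

    // Assumed sketch of an accumulate_n-style helper.
    #include <functional>
    #include <iterator>
    #include <numeric>

    namespace ck_numeric_sketch {
    template <typename ForwardIterator, typename Size, typename T, typename BinaryOp>
    T accumulate_n(ForwardIterator first, Size count, T init, BinaryOp op)
    {
        // Folds exactly `count` elements starting at `first`.
        return std::accumulate(first, std::next(first, count), init, op);
    }
    } // namespace ck_numeric_sketch

    // Usage mirroring the diff above:
    //   auto M = ck_numeric_sketch::accumulate_n(lengths.begin() + NumDimG, NumDimM,
    //                                            1, std::multiplies<int>{});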
example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m3n2k1_xdl_fp16.cpp

 // SPDX-License-Identifier: MIT
 // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+#include <cstdlib>
+#include <initializer_list>
 #include <iostream>
 #include <numeric>
-#include <initializer_list>
-#include <cstdlib>
 #include "ck/ck.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d_xdl_cshuffle.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+#include "ck/library/utility/array.hpp"
 #include "ck/library/utility/check_err.hpp"
 #include "ck/library/utility/device_memory.hpp"
 #include "ck/library/utility/host_tensor.hpp"
 #include "ck/library/utility/host_tensor_generator.hpp"
+#include "ck/library/utility/numeric.hpp"

 template <ck::index_t... Is>
 using S = ck::Sequence<Is...>;

@@ -108,7 +110,7 @@ struct ReferenceContraction_G1_M3_N2_K1 : public ck::tensor_operation::device::B
     float Run(const Argument& arg)
     {
         auto f_gs_ms_ns = [&](auto g0, auto m0, auto m1, auto m2, auto n0, auto n1) {
-            const int K0 = arg.a_gs_ms_ks_.mDesc.GetLengths()[4];
+            const int K0 = arg.a_gs_ms_ks_.GetLengths()[4];
             AccDataType v_acc = 0;

@@ -134,12 +136,12 @@ struct ReferenceContraction_G1_M3_N2_K1 : public ck::tensor_operation::device::B
         };
         make_ParallelTensorFunctor(f_gs_ms_ns,
-                                   arg.e_gs_ms_ns_.mDesc.GetLengths()[0],
-                                   arg.e_gs_ms_ns_.mDesc.GetLengths()[1],
-                                   arg.e_gs_ms_ns_.mDesc.GetLengths()[2],
-                                   arg.e_gs_ms_ns_.mDesc.GetLengths()[3],
-                                   arg.e_gs_ms_ns_.mDesc.GetLengths()[4],
-                                   arg.e_gs_ms_ns_.mDesc.GetLengths()[5])(
+                                   arg.e_gs_ms_ns_.GetLengths()[0],
+                                   arg.e_gs_ms_ns_.GetLengths()[1],
+                                   arg.e_gs_ms_ns_.GetLengths()[2],
+                                   arg.e_gs_ms_ns_.GetLengths()[3],
+                                   arg.e_gs_ms_ns_.GetLengths()[4],
+                                   arg.e_gs_ms_ns_.GetLengths()[5])(
             std::thread::hardware_concurrency());
         return 0;

@@ -246,26 +248,16 @@ int main(int argc, char* argv[])
         exit(0);
     }
-    Tensor<ADataType> a_gs_ms_ks(
-        std::vector<std::size_t>(a_gs_ms_ks_lengths.begin(), a_gs_ms_ks_lengths.end()),
-        std::vector<std::size_t>(a_gs_ms_ks_strides.begin(), a_gs_ms_ks_strides.end()));
-    Tensor<BDataType> b_gs_ns_ks(
-        std::vector<std::size_t>(b_gs_ns_ks_lengths.begin(), b_gs_ns_ks_lengths.end()),
-        std::vector<std::size_t>(b_gs_ns_ks_strides.begin(), b_gs_ns_ks_strides.end()));
-    Tensor<DDataType> d_gs_ms_ns(
-        std::vector<std::size_t>(d_gs_ms_ns_lengths.begin(), d_gs_ms_ns_lengths.end()),
-        std::vector<std::size_t>(d_gs_ms_ns_strides.begin(), d_gs_ms_ns_strides.end()));
-    Tensor<EDataType> e_gs_ms_ns_host_result(
-        std::vector<std::size_t>(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()),
-        std::vector<std::size_t>(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()));
-    Tensor<EDataType> e_gs_ms_ns_device_result(
-        std::vector<std::size_t>(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()),
-        std::vector<std::size_t>(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()));
-    std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl;
-    std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.mDesc << std::endl;
-    std::cout << "d_gs_ms_ns: " << d_gs_ms_ns.mDesc << std::endl;
-    std::cout << "e_gs_ms_ns: " << e_gs_ms_ns_host_result.mDesc << std::endl;
+    Tensor<ADataType> a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
+    Tensor<BDataType> b_gs_ns_ks(b_gs_ns_ks_lengths, b_gs_ns_ks_strides);
+    Tensor<DDataType> d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
+    Tensor<EDataType> e_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides);
+    Tensor<EDataType> e_gs_ms_ns_device_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides);
+    std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.GetDesc() << std::endl;
+    std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.GetDesc() << std::endl;
+    std::cout << "d_gs_ms_ns: " << d_gs_ms_ns.GetDesc() << std::endl;
+    std::cout << "e_gs_ms_ns: " << e_gs_ms_ns_host_result.GetDesc() << std::endl;
     switch(init_method)
     {

@@ -282,15 +274,14 @@ int main(int argc, char* argv[])
         break;
     }
-    DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize());
-    DeviceMem b_device_buf(sizeof(BDataType) * b_gs_ns_ks.mDesc.GetElementSpaceSize());
-    DeviceMem d_device_buf(sizeof(DDataType) * d_gs_ms_ns.mDesc.GetElementSpaceSize());
-    DeviceMem e_device_buf(sizeof(EDataType) * e_gs_ms_ns_device_result.mDesc.GetElementSpaceSize());
+    DeviceMem a_device_buf(a_gs_ms_ks.GetMemorySize());
+    DeviceMem b_device_buf(b_gs_ns_ks.GetMemorySize());
+    DeviceMem d_device_buf(d_gs_ms_ns.GetMemorySize());
+    DeviceMem e_device_buf(e_gs_ms_ns_device_result.GetMemorySize());
-    a_device_buf.ToDevice(a_gs_ms_ks.mData.data());
-    b_device_buf.ToDevice(b_gs_ns_ks.mData.data());
-    d_device_buf.ToDevice(d_gs_ms_ns.mData.data());
+    a_device_buf.ToDevice(a_gs_ms_ks.data());
+    b_device_buf.ToDevice(b_gs_ns_ks.data());
+    d_device_buf.ToDevice(d_gs_ms_ns.data());
     // set zero
     e_device_buf.SetZero();

@@ -299,19 +290,21 @@ int main(int argc, char* argv[])
     auto b_element_op   = BElementOp{};
     auto cde_element_op = CDEElementOp{};
+    using ck::utils::to_array;
     // device operation
     auto op       = DeviceOpInstance{};
     auto invoker  = op.MakeInvoker();
     auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(),
                                     b_device_buf.GetDeviceBuffer(),
-                                    std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()},
+                                    to_array({d_device_buf.GetDeviceBuffer()}),
                                     e_device_buf.GetDeviceBuffer(),
                                     a_gs_ms_ks_lengths,
                                     a_gs_ms_ks_strides,
                                     b_gs_ns_ks_lengths,
                                     b_gs_ns_ks_strides,
-                                    std::array<std::vector<ck::index_t>, 1>{d_gs_ms_ns_lengths},
-                                    std::array<std::vector<ck::index_t>, 1>{d_gs_ms_ns_strides},
+                                    to_array({d_gs_ms_ns_lengths}),
+                                    to_array({d_gs_ms_ns_strides}),
                                     e_gs_ms_ns_lengths,
                                     e_gs_ms_ns_strides,
                                     a_element_op,

@@ -327,20 +320,18 @@ int main(int argc, char* argv[])
     float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
-    ck::index_t M = std::accumulate(e_gs_ms_ns_lengths.begin(),
-                                    e_gs_ms_ns_lengths.begin() + NumDimM,
-                                    ck::index_t{1},
-                                    std::multiplies<ck::index_t>{});
+    ck::index_t M = ck::accumulate_n(
+        e_gs_ms_ns_lengths.begin(), NumDimM, ck::index_t{1}, std::multiplies<ck::index_t>{});
-    ck::index_t N = std::accumulate(e_gs_ms_ns_lengths.begin() + NumDimM,
-                                    e_gs_ms_ns_lengths.begin() + NumDimM + NumDimN,
-                                    ck::index_t{1},
-                                    std::multiplies<ck::index_t>{});
+    ck::index_t N = ck::accumulate_n(
+        e_gs_ms_ns_lengths.begin() + NumDimM, NumDimN, ck::index_t{1}, std::multiplies<ck::index_t>{});
-    ck::index_t K = std::accumulate(a_gs_ms_ks_lengths.begin() + NumDimM,
-                                    a_gs_ms_ks_lengths.begin() + NumDimM + NumDimK,
-                                    ck::index_t{1},
-                                    std::multiplies<ck::index_t>{});
+    ck::index_t K = ck::accumulate_n(
+        a_gs_ms_ks_lengths.begin() + NumDimM, NumDimK, ck::index_t{1}, std::multiplies<ck::index_t>{});
     std::size_t flop = std::size_t(2) * M * N * K;
     std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +

@@ -353,13 +344,11 @@ int main(int argc, char* argv[])
     std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
               << " GB/s, " << op.GetTypeString() << std::endl;
-    e_device_buf.FromDevice(e_gs_ms_ns_device_result.mData.data());
+    e_device_buf.FromDevice(e_gs_ms_ns_device_result.data());
     if(do_verification)
     {
-        Tensor<CShuffleDataType> c_gs_ms_ns_host_result(
-            std::vector<std::size_t>(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()),
-            std::vector<std::size_t>(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()));
+        Tensor<CShuffleDataType> c_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides);
         using ReferenceOpInstance = ReferenceContraction_G1_M3_N2_K1<NumDimG,
                                                                      NumDimM,

@@ -385,18 +374,17 @@ int main(int argc, char* argv[])
         ref_invoker.Run(ref_argument);
-        for(size_t g0 = 0; g0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[0]; ++g0)
+        for(size_t g0 = 0; g0 < e_gs_ms_ns_host_result.GetLengths()[0]; ++g0)
         {
-        for(size_t m0 = 0; m0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[1]; ++m0)
+        for(size_t m0 = 0; m0 < e_gs_ms_ns_host_result.GetLengths()[1]; ++m0)
         {
-        for(size_t m1 = 0; m1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[2]; ++m1)
+        for(size_t m1 = 0; m1 < e_gs_ms_ns_host_result.GetLengths()[2]; ++m1)
         {
-        for(size_t m2 = 0; m2 < e_gs_ms_ns_host_result.mDesc.GetLengths()[3]; ++m2)
+        for(size_t m2 = 0; m2 < e_gs_ms_ns_host_result.GetLengths()[3]; ++m2)
         {
-        for(size_t n0 = 0; n0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[4]; ++n0)
+        for(size_t n0 = 0; n0 < e_gs_ms_ns_host_result.GetLengths()[4]; ++n0)
         {
-        for(size_t n1 = 0; n1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[5]; ++n1)
+        for(size_t n1 = 0; n1 < e_gs_ms_ns_host_result.GetLengths()[5]; ++n1)
         {
             cde_element_op(e_gs_ms_ns_host_result(g0, m0, m1, m2, n0, n1),
                            c_gs_ms_ns_host_result(g0, m0, m1, m2, n0, n1),

@@ -408,9 +396,7 @@ int main(int argc, char* argv[])
         }
     }
-    return ck::utils::check_err(e_gs_ms_ns_device_result.mData, e_gs_ms_ns_host_result.mData)
-               ? 0
-               : 1;
+    return ck::utils::check_err(e_gs_ms_ns_device_result, e_gs_ms_ns_host_result) ? 0 : 1;
     }
     return 0;
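Both permute examples now wrap their single-D-tensor arguments with ck::utils::to_array({...}) instead of spelling out std::array<const void*, 1>{...} and std::array<std::vector<ck::index_t>, 1>{...}. The helper comes from ck/library/utility/array.hpp and is not shown in this diff; a sketch of a to_array-style function that builds a std::array from a braced list (assumed shape, illustration only):

    // Assumed sketch of a to_array-style helper.
    #include <array>
    #include <cstddef>
    #include <utility>

    namespace ck_array_sketch {
    template <typename T, std::size_t N, std::size_t... Is>
    std::array<T, N> to_array_impl(const T (&src)[N], std::index_sequence<Is...>)
    {
        return {{src[Is]...}};
    }

    // Deduces T and N from a braced list bound to an array reference, so a call
    // site can write to_array({lengths}) instead of std::array<T, 1>{lengths}.
    template <typename T, std::size_t N>
    std::array<T, N> to_array(const T (&src)[N])
    {
        return to_array_impl(src, std::make_index_sequence<N>{});
    }
    } // namespace ck_array_sketch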
example/26_contraction/contraction_bilinear_xdl_fp32.cpp

 // SPDX-License-Identifier: MIT
 // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+#include <cstdlib>
+#include <initializer_list>
 #include <iostream>
 #include <numeric>
-#include <initializer_list>
-#include <cstdlib>
 #include "ck/ck.hpp"
-#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/device_contraction_multiple_d_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+#include "ck/library/utility/array.hpp"
 #include "ck/library/utility/check_err.hpp"
 #include "ck/library/utility/device_memory.hpp"
 #include "ck/library/utility/host_tensor.hpp"
 #include "ck/library/utility/host_tensor_generator.hpp"
+#include "ck/library/utility/numeric.hpp"

 template <ck::index_t... Is>
 using S = ck::Sequence<Is...>;

@@ -122,8 +124,8 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
     float Run(const Argument& arg)
     {
         auto f_ms_ns = [&](auto m0, auto m1, auto n0, auto n1) {
-            const int K0 = arg.a_ms_ks_.mDesc.GetLengths()[2];
-            const int K1 = arg.a_ms_ks_.mDesc.GetLengths()[3];
+            const int K0 = arg.a_ms_ks_.GetLengths()[2];
+            const int K1 = arg.a_ms_ks_.GetLengths()[3];
             AccDataType v_acc = 0;

@@ -151,10 +153,10 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
         };
         make_ParallelTensorFunctor(f_ms_ns,
-                                   arg.e_ms_ns_.mDesc.GetLengths()[0],
-                                   arg.e_ms_ns_.mDesc.GetLengths()[1],
-                                   arg.e_ms_ns_.mDesc.GetLengths()[2],
-                                   arg.e_ms_ns_.mDesc.GetLengths()[3])(
+                                   arg.e_ms_ns_.GetLengths()[0],
+                                   arg.e_ms_ns_.GetLengths()[1],
+                                   arg.e_ms_ns_.GetLengths()[2],
+                                   arg.e_ms_ns_.GetLengths()[3])(
             std::thread::hardware_concurrency());
         return 0;

@@ -288,26 +290,16 @@ int main(int argc, char* argv[])
         exit(0);
     }
-    Tensor<ADataType> a_ms_ks(
-        std::vector<std::size_t>(a_ms_ks_lengths.begin(), a_ms_ks_lengths.end()),
-        std::vector<std::size_t>(a_ms_ks_strides.begin(), a_ms_ks_strides.end()));
-    Tensor<BDataType> b_ns_ks(
-        std::vector<std::size_t>(b_ns_ks_lengths.begin(), b_ns_ks_lengths.end()),
-        std::vector<std::size_t>(b_ns_ks_strides.begin(), b_ns_ks_strides.end()));
-    Tensor<EDataType> d_ms_ns(
-        std::vector<std::size_t>(d_ms_ns_lengths.begin(), d_ms_ns_lengths.end()),
-        std::vector<std::size_t>(d_ms_ns_strides.begin(), d_ms_ns_strides.end()));
-    Tensor<EDataType> e_ms_ns_host_result(
-        std::vector<std::size_t>(e_ms_ns_lengths.begin(), e_ms_ns_lengths.end()),
-        std::vector<std::size_t>(e_ms_ns_strides.begin(), e_ms_ns_strides.end()));
-    Tensor<EDataType> e_ms_ns_device_result(
-        std::vector<std::size_t>(e_ms_ns_lengths.begin(), e_ms_ns_lengths.end()),
-        std::vector<std::size_t>(e_ms_ns_strides.begin(), e_ms_ns_strides.end()));
-    std::cout << "a_ms_ks: " << a_ms_ks.mDesc << std::endl;
-    std::cout << "b_ns_ks: " << b_ns_ks.mDesc << std::endl;
-    std::cout << "d_ms_ns: " << d_ms_ns.mDesc << std::endl;
-    std::cout << "e_ms_ns: " << e_ms_ns_host_result.mDesc << std::endl;
+    Tensor<ADataType> a_ms_ks(a_ms_ks_lengths, a_ms_ks_strides);
+    Tensor<BDataType> b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides);
+    Tensor<EDataType> d_ms_ns(d_ms_ns_lengths, d_ms_ns_strides);
+    Tensor<EDataType> e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides);
+    Tensor<EDataType> e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides);
+    std::cout << "a_ms_ks: " << a_ms_ks.GetDesc() << std::endl;
+    std::cout << "b_ns_ks: " << b_ns_ks.GetDesc() << std::endl;
+    std::cout << "d_ms_ns: " << d_ms_ns.GetDesc() << std::endl;
+    std::cout << "e_ms_ns: " << e_ms_ns_host_result.GetDesc() << std::endl;
     switch(init_method)
     {

@@ -324,14 +316,14 @@ int main(int argc, char* argv[])
         break;
     }
-    DeviceMem a_device_buf(sizeof(ADataType) * a_ms_ks.mDesc.GetElementSpaceSize());
-    DeviceMem b_device_buf(sizeof(BDataType) * b_ns_ks.mDesc.GetElementSpaceSize());
-    DeviceMem d_device_buf(sizeof(DDataType) * d_ms_ns.mDesc.GetElementSpaceSize());
-    DeviceMem e_device_buf(sizeof(EDataType) * e_ms_ns_device_result.mDesc.GetElementSpaceSize());
+    DeviceMem a_device_buf(a_ms_ks.GetMemorySize());
+    DeviceMem b_device_buf(b_ns_ks.GetMemorySize());
+    DeviceMem d_device_buf(d_ms_ns.GetMemorySize());
+    DeviceMem e_device_buf(e_ms_ns_device_result.GetMemorySize());
-    a_device_buf.ToDevice(a_ms_ks.mData.data());
-    b_device_buf.ToDevice(b_ns_ks.mData.data());
-    d_device_buf.ToDevice(d_ms_ns.mData.data());
+    a_device_buf.ToDevice(a_ms_ks.data());
+    b_device_buf.ToDevice(b_ns_ks.data());
+    d_device_buf.ToDevice(d_ms_ns.data());
     // set zero
     e_device_buf.SetZero();

@@ -340,19 +332,21 @@ int main(int argc, char* argv[])
     auto b_element_op   = BElementOp{};
     auto cde_element_op = CDEElementOp{alpha, beta};
+    using ck::utils::to_array;
     // device operation
     auto op       = DeviceOpInstance{};
     auto invoker  = op.MakeInvoker();
     auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(),
                                     b_device_buf.GetDeviceBuffer(),
-                                    std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()},
+                                    to_array({d_device_buf.GetDeviceBuffer()}),
                                     e_device_buf.GetDeviceBuffer(),
                                     a_ms_ks_lengths,
                                     a_ms_ks_strides,
                                     b_ns_ks_lengths,
                                     b_ns_ks_strides,
-                                    std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
-                                    std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
+                                    to_array({d_ms_ns_lengths}),
+                                    to_array({d_ms_ns_strides}),
                                     e_ms_ns_lengths,
                                     e_ms_ns_strides,
                                     a_element_op,

@@ -368,20 +362,14 @@ int main(int argc, char* argv[])
     float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
-    ck::index_t M = std::accumulate(e_ms_ns_lengths.begin(),
-                                    e_ms_ns_lengths.begin() + NumDimM,
-                                    ck::index_t{1},
-                                    std::multiplies<ck::index_t>{});
+    ck::index_t M = ck::accumulate_n(
+        e_ms_ns_lengths.begin(), NumDimM, ck::index_t{1}, std::multiplies<ck::index_t>{});
-    ck::index_t N = std::accumulate(e_ms_ns_lengths.begin() + NumDimM,
-                                    e_ms_ns_lengths.begin() + NumDimM + NumDimN,
-                                    ck::index_t{1},
-                                    std::multiplies<ck::index_t>{});
+    ck::index_t N = ck::accumulate_n(
+        e_ms_ns_lengths.begin() + NumDimM, NumDimN, ck::index_t{1}, std::multiplies<ck::index_t>{});
-    ck::index_t K = std::accumulate(a_ms_ks_lengths.begin() + NumDimM,
-                                    a_ms_ks_lengths.begin() + NumDimM + NumDimK,
-                                    ck::index_t{1},
-                                    std::multiplies<ck::index_t>{});
+    ck::index_t K = ck::accumulate_n(
+        a_ms_ks_lengths.begin() + NumDimM, NumDimK, ck::index_t{1}, std::multiplies<ck::index_t>{});
     std::size_t flop = std::size_t(2) * M * N * K;
     std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +

@@ -394,13 +382,11 @@ int main(int argc, char* argv[])
     std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
               << " GB/s, " << op.GetTypeString() << std::endl;
-    e_device_buf.FromDevice(e_ms_ns_device_result.mData.data());
+    e_device_buf.FromDevice(e_ms_ns_device_result.data());
     if(do_verification)
     {
-        Tensor<CShuffleDataType> c_ms_ns_host_result(
-            std::vector<std::size_t>(e_ms_ns_lengths.begin(), e_ms_ns_lengths.end()),
-            std::vector<std::size_t>(e_ms_ns_strides.begin(), e_ms_ns_strides.end()));
+        Tensor<CShuffleDataType> c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides);
         using ReferenceOpInstance = ReferenceContraction_M2_N2_K2<NumDimM,
                                                                   NumDimN,

@@ -421,13 +407,13 @@ int main(int argc, char* argv[])
         ref_invoker.Run(ref_argument);
-        for(size_t m0 = 0; m0 < e_ms_ns_host_result.mDesc.GetLengths()[0]; ++m0)
+        for(size_t m0 = 0; m0 < e_ms_ns_host_result.GetLengths()[0]; ++m0)
         {
-        for(size_t m1 = 0; m1 < e_ms_ns_host_result.mDesc.GetLengths()[1]; ++m1)
+        for(size_t m1 = 0; m1 < e_ms_ns_host_result.GetLengths()[1]; ++m1)
         {
-        for(size_t n0 = 0; n0 < e_ms_ns_host_result.mDesc.GetLengths()[2]; ++n0)
+        for(size_t n0 = 0; n0 < e_ms_ns_host_result.GetLengths()[2]; ++n0)
         {
-        for(size_t n1 = 0; n1 < e_ms_ns_host_result.mDesc.GetLengths()[3]; ++n1)
+        for(size_t n1 = 0; n1 < e_ms_ns_host_result.GetLengths()[3]; ++n1)
         {
             cde_element_op(e_ms_ns_host_result(m0, m1, n0, n1),
                            c_ms_ns_host_result(m0, m1, n0, n1),

@@ -437,7 +423,7 @@ int main(int argc, char* argv[])
         }
     }
-    return ck::utils::check_err(e_ms_ns_device_result.mData, e_ms_ns_host_result.mData) ? 0 : 1;
+    return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1;
     }
     return 0;
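The verification paths above now hand whole Tensor objects to ck::utils::check_err instead of their mData vectors, which suggests check_err accepts any range-like container. A heavily simplified sketch of such a range-based comparison, assuming the arguments expose begin()/end() (the real check_err in ck/library/utility/check_err.hpp also handles tolerances and per-type error thresholds, which are omitted here):

    // Assumed, simplified sketch of a range-based check_err.
    #include <iostream>
    #include <iterator>
    #include <string>

    namespace ck_check_sketch {
    template <typename RangeA, typename RangeB>
    bool check_err(const RangeA& out, const RangeB& ref,
                   const std::string& msg = "Error: Incorrect results!")
    {
        if(std::size(out) != std::size(ref))
        {
            std::cerr << msg << " (size mismatch)" << std::endl;
            return false;
        }
        auto it_out = std::begin(out);
        auto it_ref = std::begin(ref);
        for(; it_out != std::end(out); ++it_out, ++it_ref)
        {
            if(!(*it_out == *it_ref)) // exact comparison only in this sketch
            {
                std::cerr << msg << std::endl;
                return false;
            }
        }
        return true;
    }
    } // namespace ck_check_sketch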
example/26_contraction/contraction_scale_xdl_fp32.cpp

 // SPDX-License-Identifier: MIT
 // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+#include <cstdlib>
+#include <initializer_list>
 #include <iostream>
 #include <numeric>
-#include <initializer_list>
-#include <cstdlib>
 #include "ck/ck.hpp"
-#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/device_contraction_multiple_d_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+#include "ck/library/utility/array.hpp"
 #include "ck/library/utility/check_err.hpp"
 #include "ck/library/utility/device_memory.hpp"
 #include "ck/library/utility/host_tensor.hpp"
 #include "ck/library/utility/host_tensor_generator.hpp"
+#include "ck/library/utility/numeric.hpp"

 template <ck::index_t... Is>
 using S = ck::Sequence<Is...>;

@@ -121,8 +123,8 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
     float Run(const Argument& arg)
     {
         auto f_ms_ns = [&](auto m0, auto m1, auto n0, auto n1) {
-            const int K0 = arg.a_ms_ks_.mDesc.GetLengths()[2];
-            const int K1 = arg.a_ms_ks_.mDesc.GetLengths()[3];
+            const int K0 = arg.a_ms_ks_.GetLengths()[2];
+            const int K1 = arg.a_ms_ks_.GetLengths()[3];
             AccDataType v_acc = 0;

@@ -150,10 +152,10 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
         };
         make_ParallelTensorFunctor(f_ms_ns,
-                                   arg.e_ms_ns_.mDesc.GetLengths()[0],
-                                   arg.e_ms_ns_.mDesc.GetLengths()[1],
-                                   arg.e_ms_ns_.mDesc.GetLengths()[2],
-                                   arg.e_ms_ns_.mDesc.GetLengths()[3])(
+                                   arg.e_ms_ns_.GetLengths()[0],
+                                   arg.e_ms_ns_.GetLengths()[1],
+                                   arg.e_ms_ns_.GetLengths()[2],
+                                   arg.e_ms_ns_.GetLengths()[3])(
             std::thread::hardware_concurrency());
         return 0;

@@ -277,22 +279,14 @@ int main(int argc, char* argv[])
         exit(0);
     }
-    Tensor<ADataType> a_ms_ks(
-        std::vector<std::size_t>(a_ms_ks_lengths.begin(), a_ms_ks_lengths.end()),
-        std::vector<std::size_t>(a_ms_ks_strides.begin(), a_ms_ks_strides.end()));
-    Tensor<BDataType> b_ns_ks(
-        std::vector<std::size_t>(b_ns_ks_lengths.begin(), b_ns_ks_lengths.end()),
-        std::vector<std::size_t>(b_ns_ks_strides.begin(), b_ns_ks_strides.end()));
-    Tensor<EDataType> e_ms_ns_host_result(
-        std::vector<std::size_t>(e_ms_ns_lengths.begin(), e_ms_ns_lengths.end()),
-        std::vector<std::size_t>(e_ms_ns_strides.begin(), e_ms_ns_strides.end()));
-    Tensor<EDataType> e_ms_ns_device_result(
-        std::vector<std::size_t>(e_ms_ns_lengths.begin(), e_ms_ns_lengths.end()),
-        std::vector<std::size_t>(e_ms_ns_strides.begin(), e_ms_ns_strides.end()));
-    std::cout << "a_ms_ks: " << a_ms_ks.mDesc << std::endl;
-    std::cout << "b_ns_ks: " << b_ns_ks.mDesc << std::endl;
-    std::cout << "e_ms_ns: " << e_ms_ns_host_result.mDesc << std::endl;
+    Tensor<ADataType> a_ms_ks(a_ms_ks_lengths, a_ms_ks_strides);
+    Tensor<BDataType> b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides);
+    Tensor<EDataType> e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides);
+    Tensor<EDataType> e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides);
+    std::cout << "a_ms_ks: " << a_ms_ks.GetDesc() << std::endl;
+    std::cout << "b_ns_ks: " << b_ns_ks.GetDesc() << std::endl;
+    std::cout << "e_ms_ns: " << e_ms_ns_host_result.GetDesc() << std::endl;
     switch(init_method)
     {

@@ -307,12 +301,12 @@ int main(int argc, char* argv[])
         break;
     }
-    DeviceMem a_device_buf(sizeof(ADataType) * a_ms_ks.mDesc.GetElementSpaceSize());
-    DeviceMem b_device_buf(sizeof(BDataType) * b_ns_ks.mDesc.GetElementSpaceSize());
-    DeviceMem e_device_buf(sizeof(EDataType) * e_ms_ns_device_result.mDesc.GetElementSpaceSize());
+    DeviceMem a_device_buf(a_ms_ks.GetMemorySize());
+    DeviceMem b_device_buf(b_ns_ks.GetMemorySize());
+    DeviceMem e_device_buf(e_ms_ns_device_result.GetMemorySize());
-    a_device_buf.ToDevice(a_ms_ks.mData.data());
-    b_device_buf.ToDevice(b_ns_ks.mData.data());
+    a_device_buf.ToDevice(a_ms_ks.data());
+    b_device_buf.ToDevice(b_ns_ks.data());
     // set zero
     e_device_buf.SetZero();

@@ -321,19 +315,21 @@ int main(int argc, char* argv[])
     auto b_element_op   = BElementOp{};
     auto cde_element_op = CDEElementOp{scale};
+    using ck::utils::empty_array;
     // device operation
     auto op       = DeviceOpInstance{};
     auto invoker  = op.MakeInvoker();
     auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(),
                                     b_device_buf.GetDeviceBuffer(),
-                                    std::array<const void*, 0>{},
+                                    empty_array(),
                                     e_device_buf.GetDeviceBuffer(),
                                     a_ms_ks_lengths,
                                     a_ms_ks_strides,
                                     b_ns_ks_lengths,
                                     b_ns_ks_strides,
-                                    std::array<std::vector<ck::index_t>, 0>{},
-                                    std::array<std::vector<ck::index_t>, 0>{},
+                                    empty_array(),
+                                    empty_array(),
                                     e_ms_ns_lengths,
                                     e_ms_ns_strides,
                                     a_element_op,

@@ -349,20 +345,14 @@ int main(int argc, char* argv[])
     float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
-    ck::index_t M = std::accumulate(e_ms_ns_lengths.begin(),
-                                    e_ms_ns_lengths.begin() + NumDimM,
-                                    ck::index_t{1},
-                                    std::multiplies<ck::index_t>{});
+    ck::index_t M = ck::accumulate_n(
+        e_ms_ns_lengths.begin(), NumDimM, ck::index_t{1}, std::multiplies<ck::index_t>{});
-    ck::index_t N = std::accumulate(e_ms_ns_lengths.begin() + NumDimM,
-                                    e_ms_ns_lengths.begin() + NumDimM + NumDimN,
-                                    ck::index_t{1},
-                                    std::multiplies<ck::index_t>{});
+    ck::index_t N = ck::accumulate_n(
+        e_ms_ns_lengths.begin() + NumDimM, NumDimN, ck::index_t{1}, std::multiplies<ck::index_t>{});
-    ck::index_t K = std::accumulate(a_ms_ks_lengths.begin() + NumDimM,
-                                    a_ms_ks_lengths.begin() + NumDimM + NumDimK,
-                                    ck::index_t{1},
-                                    std::multiplies<ck::index_t>{});
+    ck::index_t K = ck::accumulate_n(
+        a_ms_ks_lengths.begin() + NumDimM, NumDimK, ck::index_t{1}, std::multiplies<ck::index_t>{});
     std::size_t flop = std::size_t(2) * M * N * K;
     std::size_t num_btype =

@@ -375,13 +365,11 @@ int main(int argc, char* argv[])
     std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
               << " GB/s, " << op.GetTypeString() << std::endl;
-    e_device_buf.FromDevice(e_ms_ns_device_result.mData.data());
+    e_device_buf.FromDevice(e_ms_ns_device_result.data());
     if(do_verification)
     {
-        Tensor<CShuffleDataType> c_ms_ns_host_result(
-            std::vector<std::size_t>(e_ms_ns_lengths.begin(), e_ms_ns_lengths.end()),
-            std::vector<std::size_t>(e_ms_ns_strides.begin(), e_ms_ns_strides.end()));
+        Tensor<CShuffleDataType> c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides);
         using ReferenceOpInstance = ReferenceContraction_M2_N2_K2<NumDimM,
                                                                   NumDimN,

@@ -402,13 +390,13 @@ int main(int argc, char* argv[])
         ref_invoker.Run(ref_argument);
-        for(size_t m0 = 0; m0 < e_ms_ns_host_result.mDesc.GetLengths()[0]; ++m0)
+        for(size_t m0 = 0; m0 < e_ms_ns_host_result.GetLengths()[0]; ++m0)
         {
-        for(size_t m1 = 0; m1 < e_ms_ns_host_result.mDesc.GetLengths()[1]; ++m1)
+        for(size_t m1 = 0; m1 < e_ms_ns_host_result.GetLengths()[1]; ++m1)
        {
-        for(size_t n0 = 0; n0 < e_ms_ns_host_result.mDesc.GetLengths()[2]; ++n0)
+        for(size_t n0 = 0; n0 < e_ms_ns_host_result.GetLengths()[2]; ++n0)
         {
-        for(size_t n1 = 0; n1 < e_ms_ns_host_result.mDesc.GetLengths()[3]; ++n1)
+        for(size_t n1 = 0; n1 < e_ms_ns_host_result.GetLengths()[3]; ++n1)
         {
             cde_element_op(e_ms_ns_host_result(m0, m1, n0, n1),
                            c_ms_ns_host_result(m0, m1, n0, n1));

@@ -417,7 +405,7 @@ int main(int argc, char* argv[])
         }
     }
-    return ck::utils::check_err(e_ms_ns_device_result.mData, e_ms_ns_host_result.mData) ? 0 : 1;
+    return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1;
     }
     return 0;
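The scale-only example has no D tensors at all, so its std::array<..., 0>{} placeholders become ck::utils::empty_array(). One possible shape for that helper is an object implicitly convertible to std::array<T, 0> for any T (an assumption; the actual definition in ck/library/utility/array.hpp is not shown in this diff):

    // Assumed sketch of an empty_array-style helper: a tag object that
    // converts to std::array<T, 0> for whatever element type the callee expects,
    // so one spelling covers both the "no D pointers" and "no D lengths/strides"
    // arguments.
    #include <array>

    namespace ck_array_sketch {
    struct EmptyArrayTag
    {
        template <typename T>
        operator std::array<T, 0>() const
        {
            return {};
        }
    };

    inline EmptyArrayTag empty_array() { return {}; }
    } // namespace ck_array_sketch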
example/27_layernorm/layernorm_blockwise.cpp

 // SPDX-License-Identifier: MIT
 // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+#include <cstdlib>
+#include <initializer_list>
 #include <iostream>
 #include <numeric>
-#include <initializer_list>
-#include <cstdlib>
 #include <getopt.h>
 #include "ck/ck.hpp"
-#include "ck/utility/reduction_enums.hpp"
 #include "ck/tensor_operation/gpu/device/device_layernorm_impl.hpp"
 #include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
+#include "ck/utility/reduction_enums.hpp"
+#include "ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp"
 #include "ck/library/utility/check_err.hpp"
 #include "ck/library/utility/device_memory.hpp"
 #include "ck/library/utility/host_common_util.hpp"
 #include "ck/library/utility/host_tensor.hpp"
 #include "ck/library/utility/host_tensor_generator.hpp"
-#include "ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp"
+#include "ck/library/utility/literals.hpp"
+#include "ck/library/utility/ranges.hpp"

 using XDataType     = ck::half_t;
 using GammaDataType = ck::half_t;

@@ -59,14 +62,14 @@ int main()
     ck::index_t N      = 1024;
     ck::index_t Stride = N;
+    using namespace ck::literals;
     auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) {
-        return HostTensorDescriptor(std::vector<std::size_t>({len}),
-                                    std::vector<std::size_t>({stride}));
+        return HostTensorDescriptor({len}, {stride});
     };
     auto f_host_tensor_descriptor2d = [](std::size_t row, std::size_t col, std::size_t stride) {
-        return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
-                                    std::vector<std::size_t>({stride, 1}));
+        return HostTensorDescriptor({row, col}, {stride, 1_uz});
     };
     Tensor<XDataType> x(f_host_tensor_descriptor2d(M, N, Stride));

@@ -78,29 +81,30 @@ int main()
     gamma.GenerateTensorValue(GeneratorTensor_3<GammaDataType>{0.0, 1.0});
     beta.GenerateTensorValue(GeneratorTensor_3<BetaDataType>{0.0, 1.0});
-    DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize());
-    DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize());
-    DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpaceSize());
-    DeviceMem y_dev(sizeof(YDataType) * y.mDesc.GetElementSpaceSize());
+    DeviceMem x_dev(x.GetMemorySize());
+    DeviceMem gamma_dev(gamma.GetMemorySize());
+    DeviceMem beta_dev(beta.GetMemorySize());
+    DeviceMem y_dev(y.GetMemorySize());
-    x_dev.ToDevice(x.mData.data());
-    gamma_dev.ToDevice(gamma.mData.data());
-    beta_dev.ToDevice(beta.mData.data());
+    x_dev.ToDevice(x.data());
+    gamma_dev.ToDevice(gamma.data());
+    beta_dev.ToDevice(beta.data());
+    using Indices = std::vector<ck::index_t>;
     auto device_instance = DeviceInstance{};
-    auto argument_ptr = device_instance.MakeArgumentPointer(
-        {M, N},
-        std::vector<ck::index_t>{x.mDesc.GetStrides().begin(), x.mDesc.GetStrides().end()},
-        {0, 1},
-        {0, 1},
-        std::vector<ck::index_t>{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()},
-        {1},
-        1e-4,
-        x_dev.GetDeviceBuffer(),
-        gamma_dev.GetDeviceBuffer(),
-        beta_dev.GetDeviceBuffer(),
-        y_dev.GetDeviceBuffer(),
-        PassThrough{});
+    auto argument_ptr = device_instance.MakeArgumentPointer({M, N},
+                                                            ck::ranges::to<Indices>(x.GetStrides()),
+                                                            {0, 1},
+                                                            {0, 1},
+                                                            ck::ranges::to<Indices>(y.GetStrides()),
+                                                            {1},
+                                                            1e-4,
+                                                            x_dev.GetDeviceBuffer(),
+                                                            gamma_dev.GetDeviceBuffer(),
+                                                            beta_dev.GetDeviceBuffer(),
+                                                            y_dev.GetDeviceBuffer(),
+                                                            PassThrough{});
     if(!device_instance.IsSupportedArgument(argument_ptr.get()))
     {

@@ -129,9 +133,8 @@ int main()
     auto ref_invoker = ref.MakeInvoker();
     ref_invoker.Run(ref_argument);
-    y_dev.FromDevice(y.mData.data());
-    pass &= ck::utils::check_err(
-        y.mData, host_y.mData, "Error: Incorrect results d1", 1e-3, 1e-3);
+    y_dev.FromDevice(y.data());
+    pass &= ck::utils::check_err(y, host_y, "Error: Incorrect results d1", 1e-3, 1e-3);
     }
     return (pass ? 0 : 1);
 }
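The layernorm example builds its stride arguments with ck::ranges::to<Indices>(x.GetStrides()) instead of constructing std::vector<ck::index_t> from an explicit iterator pair. A minimal sketch of such a ranges-style conversion helper (assumed; the actual one lives in ck/library/utility/ranges.hpp and is not part of this diff):

    // Assumed sketch of a ranges::to-style conversion: copies any iterable
    // range into the requested container type.
    #include <iterator>
    #include <vector>

    namespace ck_ranges_sketch {
    template <typename Container, typename Range>
    Container to(const Range& range)
    {
        return Container(std::begin(range), std::end(range));
    }
    } // namespace ck_ranges_sketch

    // Usage mirroring the diff above (Indices = std::vector<int> in this sketch):
    //   auto strides = ck_ranges_sketch::to<std::vector<int>>(x.GetStrides());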
example/28_grouped_gemm_bias_e_permute/grouped_gemm_bias_e_permute_xdl_fp16.cpp
View file @
e4e99a49
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <initializer_list>
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/numeric.hpp"
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
...
...
@@ -104,7 +106,7 @@ struct ReferenceContraction_M3_N2_K1 : public ck::tensor_operation::device::Base
float
Run
(
const
Argument
&
arg
)
{
auto
f_ms_ns
=
[
&
](
auto
m0
,
auto
m1
,
auto
m2
,
auto
n0
,
auto
n1
)
{
const
int
K0
=
arg
.
a_ms_ks_
.
mDesc
.
GetLengths
()[
3
];
const
int
K0
=
arg
.
a_ms_ks_
.
GetLengths
()[
3
];
AccDataType
v_acc
=
0
;
...
...
@@ -129,11 +131,11 @@ struct ReferenceContraction_M3_N2_K1 : public ck::tensor_operation::device::Base
};
make_ParallelTensorFunctor
(
f_ms_ns
,
arg
.
e_ms_ns_
.
mDesc
.
GetLengths
()[
0
],
arg
.
e_ms_ns_
.
mDesc
.
GetLengths
()[
1
],
arg
.
e_ms_ns_
.
mDesc
.
GetLengths
()[
2
],
arg
.
e_ms_ns_
.
mDesc
.
GetLengths
()[
3
],
arg
.
e_ms_ns_
.
mDesc
.
GetLengths
()[
4
])(
arg
.
e_ms_ns_
.
GetLengths
()[
0
],
arg
.
e_ms_ns_
.
GetLengths
()[
1
],
arg
.
e_ms_ns_
.
GetLengths
()[
2
],
arg
.
e_ms_ns_
.
GetLengths
()[
3
],
arg
.
e_ms_ns_
.
GetLengths
()[
4
])(
std
::
thread
::
hardware_concurrency
());
return
0
;
...
...
@@ -297,33 +299,23 @@ int main(int argc, char* argv[])
const
auto
e_ms_ns_lengths
=
contraction_descs
[
i
].
e_ms_ns_lengths
;
const
auto
e_ms_ns_strides
=
contraction_descs
[
i
].
e_ms_ns_strides
;
Tensor
<
ADataType
>
a_ms_ks
(
std
::
vector
<
std
::
size_t
>
(
a_ms_ks_lengths
.
begin
(),
a_ms_ks_lengths
.
end
()),
std
::
vector
<
std
::
size_t
>
(
a_ms_ks_strides
.
begin
(),
a_ms_ks_strides
.
end
()));
Tensor
<
BDataType
>
b_ns_ks
(
std
::
vector
<
std
::
size_t
>
(
b_ns_ks_lengths
.
begin
(),
b_ns_ks_lengths
.
end
()),
std
::
vector
<
std
::
size_t
>
(
b_ns_ks_strides
.
begin
(),
b_ns_ks_strides
.
end
()));
Tensor
<
DDataType
>
d_ms_ns
(
std
::
vector
<
std
::
size_t
>
(
d_ms_ns_lengths
.
begin
(),
d_ms_ns_lengths
.
end
()),
std
::
vector
<
std
::
size_t
>
(
d_ms_ns_strides
.
begin
(),
d_ms_ns_strides
.
end
()));
Tensor
<
EDataType
>
e_ms_ns_device_result
(
std
::
vector
<
std
::
size_t
>
(
e_ms_ns_lengths
.
begin
(),
e_ms_ns_lengths
.
end
()),
std
::
vector
<
std
::
size_t
>
(
e_ms_ns_strides
.
begin
(),
e_ms_ns_strides
.
end
()));
ck
::
index_t
M_
=
std
::
accumulate
(
e_ms_ns_lengths
.
begin
(),
e_ms_ns_lengths
.
begin
()
+
NumDimM
,
ck
::
index_t
{
1
},
std
::
multiplies
<
ck
::
index_t
>
{});
ck
::
index_t
N_
=
std
::
accumulate
(
e_ms_ns_lengths
.
begin
()
+
NumDimM
,
e_ms_ns_lengths
.
begin
()
+
NumDimM
+
NumDimN
,
ck
::
index_t
{
1
},
std
::
multiplies
<
ck
::
index_t
>
{});
ck
::
index_t
K_
=
std
::
accumulate
(
a_ms_ks_lengths
.
begin
()
+
NumDimM
,
a_ms_ks_lengths
.
begin
()
+
NumDimM
+
NumDimK
,
ck
::
index_t
{
1
},
std
::
multiplies
<
ck
::
index_t
>
{});
Tensor
<
ADataType
>
a_ms_ks
(
a_ms_ks_lengths
,
a_ms_ks_strides
);
Tensor
<
BDataType
>
b_ns_ks
(
b_ns_ks_lengths
,
b_ns_ks_strides
);
Tensor
<
DDataType
>
d_ms_ns
(
d_ms_ns_lengths
,
d_ms_ns_strides
);
Tensor
<
EDataType
>
e_ms_ns_device_result
(
e_ms_ns_lengths
,
e_ms_ns_strides
);
ck
::
index_t
M_
=
ck
::
accumulate_n
(
e_ms_ns_lengths
.
begin
(),
NumDimM
,
ck
::
index_t
{
1
},
std
::
multiplies
<
ck
::
index_t
>
{});
ck
::
index_t
N_
=
ck
::
accumulate_n
(
e_ms_ns_lengths
.
begin
()
+
NumDimM
,
NumDimN
,
ck
::
index_t
{
1
},
std
::
multiplies
<
ck
::
index_t
>
{});
ck
::
index_t
K_
=
ck
::
accumulate_n
(
a_ms_ks_lengths
.
begin
()
+
NumDimM
,
NumDimK
,
ck
::
index_t
{
1
},
std
::
multiplies
<
ck
::
index_t
>
{});
a_tensors
.
push_back
(
a_ms_ks
);
b_tensors
.
push_back
(
b_ns_ks
);
...
...
@@ -334,13 +326,13 @@ int main(int argc, char* argv[])
flop
+=
std
::
size_t
(
2
)
*
M_
*
K_
*
N_
;
num_btype
+=
sizeof
(
ADataType
)
*
a_tensors
[
i
].
mDesc
.
GetElementSize
()
+
sizeof
(
BDataType
)
*
b_tensors
[
i
].
mDesc
.
GetElementSize
()
+
sizeof
(
EDataType
)
*
e_device_tensors
[
i
].
mDesc
.
GetElementSize
();
num_btype
+=
sizeof
(
ADataType
)
*
a_tensors
[
i
].
GetElementSize
()
+
sizeof
(
BDataType
)
*
b_tensors
[
i
].
GetElementSize
()
+
sizeof
(
EDataType
)
*
e_device_tensors
[
i
].
GetElementSize
();
std
::
cout
<<
"gemm["
<<
i
<<
"] a_m_k: "
<<
a_tensors
[
i
].
m
Desc
<<
" b_n_k: "
<<
b_tensors
[
i
].
mDesc
<<
" c_m_n: "
<<
e_device_tensors
[
i
].
m
Desc
<<
std
::
endl
;
std
::
cout
<<
"gemm["
<<
i
<<
"] a_m_k: "
<<
a_tensors
[
i
].
Get
Desc
()
<<
" b_n_k: "
<<
b_tensors
[
i
].
Get
Desc
()
<<
" c_m_n: "
<<
e_device_tensors
[
i
].
GetDesc
()
<<
std
::
endl
;
switch
(
init_method
)
{
...
...
@@ -364,18 +356,15 @@ int main(int argc, char* argv[])
    for(std::size_t i = 0; i < contraction_descs.size(); i++)
    {
        a_tensors_device.emplace_back(
            std::make_unique<DeviceMem>(sizeof(ADataType) * a_tensors[i].mDesc.GetElementSpaceSize()));
        b_tensors_device.emplace_back(
            std::make_unique<DeviceMem>(sizeof(BDataType) * b_tensors[i].mDesc.GetElementSpaceSize()));
        d_tensors_device.emplace_back(
            std::make_unique<DeviceMem>(sizeof(DDataType) * d_tensors[i].mDesc.GetElementSpaceSize()));
        e_tensors_device.emplace_back(
            std::make_unique<DeviceMem>(sizeof(EDataType) * e_device_tensors[i].mDesc.GetElementSpaceSize()));

        a_tensors_device[i]->ToDevice(a_tensors[i].mData.data());
        b_tensors_device[i]->ToDevice(b_tensors[i].mData.data());
        d_tensors_device[i]->ToDevice(d_tensors[i].mData.data());
        a_tensors_device.emplace_back(std::make_unique<DeviceMem>(a_tensors[i].GetMemorySize()));
        b_tensors_device.emplace_back(std::make_unique<DeviceMem>(b_tensors[i].GetMemorySize()));
        d_tensors_device.emplace_back(std::make_unique<DeviceMem>(d_tensors[i].GetMemorySize()));
        e_tensors_device.emplace_back(std::make_unique<DeviceMem>(e_device_tensors[i].GetMemorySize()));

        a_tensors_device[i]->ToDevice(a_tensors[i].data());
        b_tensors_device[i]->ToDevice(b_tensors[i].data());
        d_tensors_device[i]->ToDevice(d_tensors[i].data());

        p_a.push_back(a_tensors_device[i]->GetDeviceBuffer());
        p_b.push_back(b_tensors_device[i]->GetDeviceBuffer());
...
...
@@ -423,15 +412,11 @@ int main(int argc, char* argv[])
        const auto e_ms_ns_lengths = contraction_descs[i].e_ms_ns_lengths;
        const auto e_ms_ns_strides = contraction_descs[i].e_ms_ns_strides;

        Tensor<EDataType> c_ms_ns_host_result(
            std::vector<std::size_t>(e_ms_ns_lengths.begin(), e_ms_ns_lengths.end()),
            std::vector<std::size_t>(e_ms_ns_strides.begin(), e_ms_ns_strides.end()));
        Tensor<EDataType> c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides);
        Tensor<EDataType> e_ms_ns_host_result(
            std::vector<std::size_t>(e_ms_ns_lengths.begin(), e_ms_ns_lengths.end()),
            std::vector<std::size_t>(e_ms_ns_strides.begin(), e_ms_ns_strides.end()));
        Tensor<EDataType> e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides);

        e_tensors_device[i]->FromDevice(e_device_tensors[i].mData.data());
        e_tensors_device[i]->FromDevice(e_device_tensors[i].data());

        using ReferenceOpInstance = ReferenceContraction_M3_N2_K1<NumDimM,
                                                                  NumDimN,
...
...
@@ -456,15 +441,15 @@ int main(int argc, char* argv[])
        ref_invoker.Run(ref_argument);

        for(size_t m0 = 0; m0 < e_ms_ns_host_result.mDesc.GetLengths()[0]; ++m0)
        for(size_t m0 = 0; m0 < e_ms_ns_host_result.GetLengths()[0]; ++m0)
        {
            for(size_t m1 = 0; m1 < e_ms_ns_host_result.mDesc.GetLengths()[1]; ++m1)
            for(size_t m1 = 0; m1 < e_ms_ns_host_result.GetLengths()[1]; ++m1)
            {
                for(size_t m2 = 0; m2 < e_ms_ns_host_result.mDesc.GetLengths()[2]; ++m2)
                for(size_t m2 = 0; m2 < e_ms_ns_host_result.GetLengths()[2]; ++m2)
                {
                    for(size_t n0 = 0; n0 < e_ms_ns_host_result.mDesc.GetLengths()[3]; ++n0)
                    for(size_t n0 = 0; n0 < e_ms_ns_host_result.GetLengths()[3]; ++n0)
                    {
                        for(size_t n1 = 0; n1 < e_ms_ns_host_result.mDesc.GetLengths()[4]; ++n1)
                        for(size_t n1 = 0; n1 < e_ms_ns_host_result.GetLengths()[4]; ++n1)
                        {
                            cde_element_op(e_ms_ns_host_result(m0, m1, m2, n0, n1),
                                           c_ms_ns_host_result(m0, m1, m2, n0, n1),
...
...
@@ -475,7 +460,7 @@ int main(int argc, char* argv[])
            }
        }

        pass &= ck::utils::check_err(e_device_tensors[i].mData, e_ms_ns_host_result.mData);
        pass &= ck::utils::check_err(e_device_tensors[i], e_ms_ns_host_result);
    }
}
...
...
example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_xdl_fp16.cpp
View file @
e4e99a49
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <initializer_list>
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/array.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/numeric.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
...
...
@@ -109,7 +111,7 @@ struct ReferenceContraction_G2_M2_N2_K1 : public ck::tensor_operation::device::B
    float Run(const Argument& arg)
    {
        auto f_ms_ns = [&](auto g0, auto g1, auto m0, auto m1, auto n0, auto n1) {
            const int K0 = arg.a_gs_ms_ks_.mDesc.GetLengths()[4];
            const int K0 = arg.a_gs_ms_ks_.GetLengths()[4];

            AccDataType v_acc = 0;
...
...
@@ -136,12 +138,12 @@ struct ReferenceContraction_G2_M2_N2_K1 : public ck::tensor_operation::device::B
        };

        make_ParallelTensorFunctor(f_ms_ns,
                                   arg.e_gs_ms_ns_.mDesc.GetLengths()[0],
                                   arg.e_gs_ms_ns_.mDesc.GetLengths()[1],
                                   arg.e_gs_ms_ns_.mDesc.GetLengths()[2],
                                   arg.e_gs_ms_ns_.mDesc.GetLengths()[3],
                                   arg.e_gs_ms_ns_.mDesc.GetLengths()[4],
                                   arg.e_gs_ms_ns_.mDesc.GetLengths()[5])(
                                   arg.e_gs_ms_ns_.GetLengths()[0],
                                   arg.e_gs_ms_ns_.GetLengths()[1],
                                   arg.e_gs_ms_ns_.GetLengths()[2],
                                   arg.e_gs_ms_ns_.GetLengths()[3],
                                   arg.e_gs_ms_ns_.GetLengths()[4],
                                   arg.e_gs_ms_ns_.GetLengths()[5])(
            std::thread::hardware_concurrency());

        return 0;
...
...
@@ -246,26 +248,16 @@ int main(int argc, char* argv[])
        exit(0);
    }

    Tensor<ADataType> a_gs_ms_ks(std::vector<std::size_t>(a_gs_ms_ks_lengths.begin(), a_gs_ms_ks_lengths.end()),
                                 std::vector<std::size_t>(a_gs_ms_ks_strides.begin(), a_gs_ms_ks_strides.end()));
    Tensor<BDataType> b_gs_ns_ks(std::vector<std::size_t>(b_gs_ns_ks_lengths.begin(), b_gs_ns_ks_lengths.end()),
                                 std::vector<std::size_t>(b_gs_ns_ks_strides.begin(), b_gs_ns_ks_strides.end()));
    Tensor<DDataType> d_gs_ms_ns(std::vector<std::size_t>(d_gs_ms_ns_lengths.begin(), d_gs_ms_ns_lengths.end()),
                                 std::vector<std::size_t>(d_gs_ms_ns_strides.begin(), d_gs_ms_ns_strides.end()));
    Tensor<EDataType> e_gs_ms_ns_host_result(
        std::vector<std::size_t>(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()),
        std::vector<std::size_t>(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()));
    Tensor<EDataType> e_gs_ms_ns_device_result(
        std::vector<std::size_t>(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()),
        std::vector<std::size_t>(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()));

    std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl;
    std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.mDesc << std::endl;
    std::cout << "d_gs_ms_ns: " << d_gs_ms_ns.mDesc << std::endl;
    std::cout << "e_gs_ms_ns: " << e_gs_ms_ns_host_result.mDesc << std::endl;
    Tensor<ADataType> a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
    Tensor<BDataType> b_gs_ns_ks(b_gs_ns_ks_lengths, b_gs_ns_ks_strides);
    Tensor<DDataType> d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
    Tensor<EDataType> e_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides);
    Tensor<EDataType> e_gs_ms_ns_device_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides);

    std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.GetDesc() << std::endl;
    std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.GetDesc() << std::endl;
    std::cout << "d_gs_ms_ns: " << d_gs_ms_ns.GetDesc() << std::endl;
    std::cout << "e_gs_ms_ns: " << e_gs_ms_ns_host_result.GetDesc() << std::endl;

    switch(init_method)
    {
...
...
@@ -282,15 +274,14 @@ int main(int argc, char* argv[])
        break;
    }

    DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize());
    DeviceMem b_device_buf(sizeof(BDataType) * b_gs_ns_ks.mDesc.GetElementSpaceSize());
    DeviceMem d_device_buf(sizeof(DDataType) * d_gs_ms_ns.mDesc.GetElementSpaceSize());
    DeviceMem e_device_buf(sizeof(EDataType) * e_gs_ms_ns_device_result.mDesc.GetElementSpaceSize());
    DeviceMem a_device_buf(a_gs_ms_ks.GetMemorySize());
    DeviceMem b_device_buf(b_gs_ns_ks.GetMemorySize());
    DeviceMem d_device_buf(d_gs_ms_ns.GetMemorySize());
    DeviceMem e_device_buf(e_gs_ms_ns_device_result.GetMemorySize());

    a_device_buf.ToDevice(a_gs_ms_ks.mData.data());
    b_device_buf.ToDevice(b_gs_ns_ks.mData.data());
    d_device_buf.ToDevice(d_gs_ms_ns.mData.data());
    a_device_buf.ToDevice(a_gs_ms_ks.data());
    b_device_buf.ToDevice(b_gs_ns_ks.data());
    d_device_buf.ToDevice(d_gs_ms_ns.data());

    // set zero
    e_device_buf.SetZero();
...
...
@@ -299,19 +290,21 @@ int main(int argc, char* argv[])
    auto b_element_op   = BElementOp{};
    auto cde_element_op = CDEElementOp{};

    using ck::utils::to_array;

    // device operation
    auto op       = DeviceOpInstance{};
    auto invoker  = op.MakeInvoker();
    auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(),
                                    b_device_buf.GetDeviceBuffer(),
                                    std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()},
                                    to_array({d_device_buf.GetDeviceBuffer()}),
                                    e_device_buf.GetDeviceBuffer(),
                                    a_gs_ms_ks_lengths,
                                    a_gs_ms_ks_strides,
                                    b_gs_ns_ks_lengths,
                                    b_gs_ns_ks_strides,
                                    std::array<std::vector<ck::index_t>, 1>{d_gs_ms_ns_lengths},
                                    std::array<std::vector<ck::index_t>, 1>{d_gs_ms_ns_strides},
                                    to_array({d_gs_ms_ns_lengths}),
                                    to_array({d_gs_ms_ns_strides}),
                                    e_gs_ms_ns_lengths,
                                    e_gs_ms_ns_strides,
                                    a_element_op,
...
...
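The argument list above replaces the spelled-out std::array<...> temporaries with ck::utils::to_array. This diff does not show the CK definition, so the following is only a hedged, stand-alone sketch of a to_array-style helper for the braced-list case; the real helper presumably also accepts containers such as the length/stride vectors passed elsewhere in this commit.

// Hedged sketch (illustrative, not the CK source): build a std::array from a braced list
// without writing out the element type and count at the call site.
#include <array>
#include <cstddef>

template <typename T, std::size_t N>
constexpr std::array<T, N> to_array(const T (&src)[N])
{
    std::array<T, N> result{};
    for(std::size_t i = 0; i < N; ++i)
        result[i] = src[i]; // element-wise copy into the deduced std::array
    return result;
}

int main()
{
    int a = 1, b = 2;
    auto arr = to_array({&a, &b}); // deduced as std::array<int*, 2>
    return arr.size() == 2 ? 0 : 1;
}

The payoff at the call sites above is that adding or removing a D-tensor no longer requires editing a hand-written array size.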
@@ -327,25 +320,23 @@ int main(int argc, char* argv[])
    float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});

    ck::index_t G = std::accumulate(e_gs_ms_ns_lengths.begin(),
                                    e_gs_ms_ns_lengths.begin() + NumDimG,
                                    ck::index_t{1},
                                    std::multiplies<ck::index_t>{});
    ck::index_t G = ck::accumulate_n(
        e_gs_ms_ns_lengths.begin(), NumDimG, ck::index_t{1}, std::multiplies<ck::index_t>{});
    ck::index_t M = std::accumulate(e_gs_ms_ns_lengths.begin() + NumDimG,
                                    e_gs_ms_ns_lengths.begin() + NumDimG + NumDimM,
                                    ck::index_t{1},
                                    std::multiplies<ck::index_t>{});
    ck::index_t M = ck::accumulate_n(
        e_gs_ms_ns_lengths.begin() + NumDimG, NumDimM, ck::index_t{1}, std::multiplies<ck::index_t>{});
    ck::index_t N = std::accumulate(e_gs_ms_ns_lengths.begin() + NumDimG + NumDimM,
                                    e_gs_ms_ns_lengths.begin() + NumDimG + NumDimM + NumDimN,
                                    ck::index_t{1},
                                    std::multiplies<ck::index_t>{});
    ck::index_t N = ck::accumulate_n(
        e_gs_ms_ns_lengths.begin() + NumDimG + NumDimM, NumDimN, ck::index_t{1}, std::multiplies<ck::index_t>{});
    ck::index_t K = std::accumulate(a_gs_ms_ks_lengths.begin() + NumDimG + NumDimM,
                                    a_gs_ms_ks_lengths.begin() + NumDimG + NumDimM + NumDimK,
                                    ck::index_t{1},
                                    std::multiplies<ck::index_t>{});
    ck::index_t K = ck::accumulate_n(
        a_gs_ms_ks_lengths.begin() + NumDimG + NumDimM, NumDimK, ck::index_t{1}, std::multiplies<ck::index_t>{});

    std::size_t flop = std::size_t(2) * G * M * N * K;
    std::size_t num_btype = sizeof(ADataType) * G * M * K + sizeof(BDataType) * G * K * N +
...
...
@@ -358,13 +349,11 @@ int main(int argc, char* argv[])
    std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
              << op.GetTypeString() << std::endl;

    e_device_buf.FromDevice(e_gs_ms_ns_device_result.mData.data());
    e_device_buf.FromDevice(e_gs_ms_ns_device_result.data());

    if(do_verification)
    {
        Tensor<CShuffleDataType> c_ms_ns_host_result(
            std::vector<std::size_t>(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()),
            std::vector<std::size_t>(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()));
        Tensor<CShuffleDataType> c_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides);

        using ReferenceOpInstance = ReferenceContraction_G2_M2_N2_K1<NumDimG,
                                                                     NumDimM,
...
...
@@ -386,18 +375,17 @@ int main(int argc, char* argv[])
        ref_invoker.Run(ref_argument);

        for(size_t g0 = 0; g0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[0]; ++g0)
        for(size_t g0 = 0; g0 < e_gs_ms_ns_host_result.GetLengths()[0]; ++g0)
        {
            for(size_t g1 = 0; g1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[1]; ++g1)
            for(size_t g1 = 0; g1 < e_gs_ms_ns_host_result.GetLengths()[1]; ++g1)
            {
                for(size_t m0 = 0; m0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[2]; ++m0)
                for(size_t m0 = 0; m0 < e_gs_ms_ns_host_result.GetLengths()[2]; ++m0)
                {
                    for(size_t m1 = 0; m1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[3]; ++m1)
                    for(size_t m1 = 0; m1 < e_gs_ms_ns_host_result.GetLengths()[3]; ++m1)
                    {
                        for(size_t n0 = 0; n0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[4]; ++n0)
                        for(size_t n0 = 0; n0 < e_gs_ms_ns_host_result.GetLengths()[4]; ++n0)
                        {
                            for(size_t n1 = 0; n1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[5]; ++n1)
                            for(size_t n1 = 0; n1 < e_gs_ms_ns_host_result.GetLengths()[5]; ++n1)
                            {
                                cde_element_op(e_gs_ms_ns_host_result(g0, g1, m0, m1, n0, n1),
                                               c_ms_ns_host_result(g0, g1, m0, m1, n0, n1),
...
...
@@ -409,9 +397,7 @@ int main(int argc, char* argv[])
            }
        }

        return ck::utils::check_err(e_gs_ms_ns_device_result.mData, e_gs_ms_ns_host_result.mData) ? 0 : 1;
        return ck::utils::check_err(e_gs_ms_ns_device_result, e_gs_ms_ns_host_result) ? 0 : 1;
    }

    return 0;
...
...
example/30_grouped_convnd_fwd_bias_relu_add/grouped_convnd_fwd_bias_relu_add_common.hpp
View file @
e4e99a49
...
...
@@ -10,12 +10,14 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/array.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
void print_helper_msg()
{
...
...
@@ -57,11 +59,11 @@ int run_grouped_conv_fwd_bias_relu_add(bool do_verification,
    Tensor<OutUserDataType> out_host(out_g_n_k_wos_desc);
    Tensor<OutKernelDataType> out_device(out_g_n_k_wos_desc);

    std::cout << "in: " << in.mDesc << std::endl;
    std::cout << "wei: " << wei.mDesc << std::endl;
    std::cout << "bias: " << bias.mDesc << std::endl;
    std::cout << "residual: " << residual.mDesc << std::endl;
    std::cout << "out: " << out_host.mDesc << std::endl;
    std::cout << "in: " << in.GetDesc() << std::endl;
    std::cout << "wei: " << wei.GetDesc() << std::endl;
    std::cout << "bias: " << bias.GetDesc() << std::endl;
    std::cout << "residual: " << residual.GetDesc() << std::endl;
    std::cout << "out: " << out_host.GetDesc() << std::endl;

    switch(init_method)
    {
...
...
@@ -77,11 +79,11 @@ int run_grouped_conv_fwd_bias_relu_add(bool do_verification,
        bias.GenerateTensorValue(GeneratorTensor_3<OutUserDataType>{-0.5, 0.5});
    }

    DeviceMem in_device_buf(sizeof(InKernelDataType) * in.mDesc.GetElementSpaceSize());
    DeviceMem wei_device_buf(sizeof(WeiKernelDataType) * wei.mDesc.GetElementSpaceSize());
    DeviceMem bias_device_buf(sizeof(OutKernelDataType) * bias.mDesc.GetElementSpaceSize());
    DeviceMem residual_device_buf(sizeof(OutKernelDataType) * residual.mDesc.GetElementSpaceSize());
    DeviceMem out_device_buf(sizeof(OutKernelDataType) * out_device.mDesc.GetElementSpaceSize());
    DeviceMem in_device_buf(in.GetMemorySize());
    DeviceMem wei_device_buf(wei.GetMemorySize());
    DeviceMem bias_device_buf(bias.GetMemorySize());
    DeviceMem residual_device_buf(residual.GetMemorySize());
    DeviceMem out_device_buf(out_device.GetMemorySize());

#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
    const Tensor<InKernelDataType> in_converted(in);
...
...
@@ -89,75 +91,54 @@ int run_grouped_conv_fwd_bias_relu_add(bool do_verification,
    const Tensor<OutKernelDataType> bias_converted(bias);
    const Tensor<OutKernelDataType> residual_converted(residual);

    in_device_buf.ToDevice(in_converted.mData.data());
    wei_device_buf.ToDevice(wei_converted.mData.data());
    bias_device_buf.ToDevice(bias_converted.mData.data());
    residual_device_buf.ToDevice(residual_converted.mData.data());
    in_device_buf.ToDevice(in_converted.data());
    wei_device_buf.ToDevice(wei_converted.data());
    bias_device_buf.ToDevice(bias_converted.data());
    residual_device_buf.ToDevice(residual_converted.data());
#else // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
    in_device_buf.ToDevice(in.mData.data());
    wei_device_buf.ToDevice(wei.mData.data());
    bias_device_buf.ToDevice(bias.mData.data());
    residual_device_buf.ToDevice(residual.mData.data());
    in_device_buf.ToDevice(in.data());
    wei_device_buf.ToDevice(wei.data());
    bias_device_buf.ToDevice(bias.data());
    residual_device_buf.ToDevice(residual.data());
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4

    std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_lengths{};
    std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_strides{};
    std::array<ck::index_t, NDimSpatial + 3> b_g_k_c_xs_lengths{};
    std::array<ck::index_t, NDimSpatial + 3> b_g_k_c_xs_strides{};
    std::array<ck::index_t, NDimSpatial + 3> d0_g_n_k_wos_lengths{};
    std::array<ck::index_t, NDimSpatial + 3> d0_g_n_k_wos_strides{};
    std::array<ck::index_t, NDimSpatial + 3> d1_g_n_k_wos_lengths{};
    std::array<ck::index_t, NDimSpatial + 3> d1_g_n_k_wos_strides{};
    std::array<ck::index_t, NDimSpatial + 3> e_g_n_k_wos_lengths{};
    std::array<ck::index_t, NDimSpatial + 3> e_g_n_k_wos_strides{};
    std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
    std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
    std::array<ck::index_t, NDimSpatial> input_left_pads{};
    std::array<ck::index_t, NDimSpatial> input_right_pads{};

    auto copy = [](auto& x, auto& y) { std::copy(x.begin(), x.end(), y.begin()); };

    copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths);
    copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides);
    copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths);
    copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides);
    auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); };

    copy(bias_g_n_k_wos_desc.GetLengths(), d0_g_n_k_wos_lengths);
    copy(bias_g_n_k_wos_desc.GetStrides(), d0_g_n_k_wos_strides);
    copy(residual_g_n_k_wos_desc.GetLengths(), d1_g_n_k_wos_lengths);
    copy(residual_g_n_k_wos_desc.GetStrides(), d1_g_n_k_wos_strides);
    copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths);
    copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides);
    copy(conv_param.conv_filter_strides_, conv_filter_strides);
    copy(conv_param.conv_filter_dilations_, conv_filter_dilations);
    copy(conv_param.input_left_pads_, input_left_pads);
    copy(conv_param.input_right_pads_, input_right_pads);

    using ck::utils::to_array;

    // do Conv
    auto conv     = DeviceConvNDFwdInstance{};
    auto invoker  = conv.MakeInvoker();
    auto argument = conv.MakeArgument(
        in_device_buf.GetDeviceBuffer(),
        wei_device_buf.GetDeviceBuffer(),
        std::array<const void*, 2>{bias_device_buf.GetDeviceBuffer(), residual_device_buf.GetDeviceBuffer()},
        out_device_buf.GetDeviceBuffer(),
        a_g_n_c_wis_lengths,
        a_g_n_c_wis_strides,
        b_g_k_c_xs_lengths,
        b_g_k_c_xs_strides,
        std::array<std::array<ck::index_t, NDimSpatial + 3>, 2>{{d0_g_n_k_wos_lengths, d1_g_n_k_wos_lengths}},
        std::array<std::array<ck::index_t, NDimSpatial + 3>, 2>{{d0_g_n_k_wos_strides, d1_g_n_k_wos_strides}},
        e_g_n_k_wos_lengths,
        e_g_n_k_wos_strides,
        conv_filter_strides,
        conv_filter_dilations,
        input_left_pads,
        input_right_pads,
        in_element_op,
        wei_element_op,
        out_element_op);
    auto conv     = DeviceConvNDFwdInstance{};
    auto invoker  = conv.MakeInvoker();
    auto argument = conv.MakeArgument(
        in_device_buf.GetDeviceBuffer(),
        wei_device_buf.GetDeviceBuffer(),
        to_array({bias_device_buf.GetDeviceBuffer(), residual_device_buf.GetDeviceBuffer()}),
        out_device_buf.GetDeviceBuffer(),
        to_array(in_g_n_c_wis_desc.GetLengths()),
        to_array(in_g_n_c_wis_desc.GetStrides()),
        to_array(wei_g_k_c_xs_desc.GetLengths()),
        to_array(wei_g_k_c_xs_desc.GetStrides()),
        to_array({d0_g_n_k_wos_lengths, d1_g_n_k_wos_lengths}),
        to_array({d0_g_n_k_wos_strides, d1_g_n_k_wos_strides}),
        to_array(out_g_n_k_wos_desc.GetLengths()),
        to_array(out_g_n_k_wos_desc.GetStrides()),
        to_array(conv_param.conv_filter_strides_),
        to_array(conv_param.conv_filter_dilations_),
        to_array(conv_param.input_left_pads_),
        to_array(conv_param.input_right_pads_),
        in_element_op,
        wei_element_op,
        out_element_op);

    if(!conv.IsSupportedArgument(argument))
    {
...
...
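In the conv example above, the copy lambda now delegates to ck::ranges::copy, passing the whole range instead of a begin/end pair. The following is only a sketch of what such a range-based copy wrapper could look like, assuming it simply forwards to std::copy (illustrative names, not the CK source):

// Hedged sketch of a ranges-style copy wrapper: the range goes in whole,
// the destination is still an output iterator.
#include <algorithm>
#include <array>
#include <iterator>
#include <vector>

namespace toy_ranges {
template <typename Range, typename OutputIterator>
OutputIterator copy(const Range& range, OutputIterator out)
{
    return std::copy(std::begin(range), std::end(range), out);
}
} // namespace toy_ranges

int main()
{
    std::vector<int> lengths{1, 4, 8, 8, 3};
    std::array<int, 5> dst{};
    // same effect as std::copy(lengths.begin(), lengths.end(), dst.begin())
    toy_ranges::copy(lengths, dst.begin());
    return dst[1] == 4 ? 0 : 1;
}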
@@ -209,21 +190,17 @@ int run_grouped_conv_fwd_bias_relu_add(bool do_verification,
            out_element_op(out_host(idx), c_host(idx), bias(idx), residual(idx));
        });

        out_device_buf.FromDevice(out_device.mData.data());
        out_device_buf.FromDevice(out_device.data());

#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
        const Tensor<OutUserDataType> out_device_converted(out_device);

        return ck::utils::check_err(
                   out_device_converted.mData, out_host.mData, "Error: incorrect results!", 1e-5f, 1e-4f)
        return ck::utils::check_err(
                   out_device_converted, out_host, "Error: incorrect results!", 1e-5f, 1e-4f)
                   ? 0
                   : 1;
#else // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
        return ck::utils::check_err(
                   out_device.mData, out_host.mData, "Error: incorrect results!", 1e-5f, 1e-4f)
        return ck::utils::check_err(
                   out_device, out_host, "Error: incorrect results!", 1e-5f, 1e-4f)
                   ? 0
                   : 1;
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
...
...
example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_bf16.cpp
View file @
e4e99a49
...
...
@@ -9,32 +9,7 @@ Gemm + Gemm fused operation. Computes C_m_o = A_m_k * B0_k_n * B1_n_o
Gemm1
*/
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using BF16 = ck::bhalf_t;
using F32  = float;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;
#include "common.hpp"

using ADataType  = BF16;
using B0DataType = BF16;
...
...
example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_fp16.cpp
View file @
e4e99a49
...
...
@@ -9,32 +9,7 @@ Gemm + Gemm fused operation. Computes C_m_o = A_m_k * B0_k_n * B1_n_o
Gemm1
*/
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using F16 = ck::half_t;
using F32 = float;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;
#include "common.hpp"

using ADataType  = F16;
using B0DataType = F16;
...
...
example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_fp32.cpp
View file @
e4e99a49
...
...
@@ -9,31 +9,7 @@ Gemm + Gemm fused operation. Computes C_m_o = A_m_k * B0_k_n * B1_n_o
Gemm1
*/
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using F32 = float;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;
#include "common.hpp"

using ADataType  = F32;
using B0DataType = F32;
...
...
example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_int4.cpp
View file @
e4e99a49
...
...
@@ -13,29 +13,7 @@ Gemm + Gemm fused operation. Computes C_m_o = A_m_k * B0_k_n * B1_n_o
#error Should compile this file with ck::int4_t support
#endif
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;
#include "common.hpp"

using ADataType  = ck::int4_t;
using B0DataType = ck::int4_t;
...
...
example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_int8.cpp
View file @
e4e99a49
...
...
@@ -9,29 +9,7 @@ Gemm + Gemm fused operation. Computes C_m_o = A_m_k * B0_k_n * B1_n_o
Gemm1
*/
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;
#include "common.hpp"

using ADataType  = int8_t;
using B0DataType = int8_t;
...
...
example/31_batched_gemm_gemm/common.hpp
0 → 100644
View file @
e4e99a49
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using BF16 = ck::bhalf_t;
using F16  = ck::half_t;
using F32  = float;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;
example/31_batched_gemm_gemm/run_batched_gemm_gemm_example.inc
View file @
e4e99a49
...
...
@@ -100,21 +100,21 @@ bool run_batched_gemm_gemm_example(int argc, char* argv[])
    BatchStrideB1 = BatchStrideB1 < 0 ? DefaultBatchStrideB1 : BatchStrideB1;
    BatchStrideC  = BatchStrideC < 0 ? DefaultBatchStrideC : BatchStrideC;

    using namespace ck::literals;

    auto f_host_tensor_descriptor = [](std::size_t batch_count,
                                       std::size_t row,
                                       std::size_t col,
                                       std::size_t stride,
                                       std::size_t batch_stride,
                                       auto layout) {
        if(std::is_same<decltype(layout), Row>::value)
        if constexpr(std::is_same_v<decltype(layout), Row>)
        {
            return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
                                        std::vector<std::size_t>({batch_stride, stride, 1}));
            return HostTensorDescriptor({batch_count, row, col}, {batch_stride, stride, 1_uz});
        }
        else
        {
            return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
                                        std::vector<std::size_t>({batch_stride, 1, stride}));
            return HostTensorDescriptor({batch_count, row, col}, {batch_stride, 1_uz, stride});
        }
    };
...
...
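The descriptor lambda above now builds HostTensorDescriptor from braced lists and uses the 1_uz literal pulled in from ck/library/utility/literals.hpp, so every element of the stride list is a std::size_t. A stand-alone sketch of such a literal together with a toy descriptor type follows; it only assumes that the real HostTensorDescriptor has braced-list constructors, and the ToyDescriptor below is merely a stand-in for illustration.

// Hedged sketch: a size_t user-defined literal plus a toy descriptor taking braced lists.
#include <cstddef>
#include <initializer_list>
#include <vector>

constexpr std::size_t operator""_uz(unsigned long long value) { return value; }

struct ToyDescriptor // stand-in for HostTensorDescriptor's braced-list constructors
{
    std::vector<std::size_t> lengths;
    std::vector<std::size_t> strides;

    ToyDescriptor(std::initializer_list<std::size_t> l, std::initializer_list<std::size_t> s)
        : lengths(l), strides(s)
    {
    }
};

int main()
{
    std::size_t batch_stride = 64, stride = 8;
    // 1_uz spells a std::size_t constant, so every element of the braced list has the same type.
    ToyDescriptor desc({2, 4, 8}, {batch_stride, stride, 1_uz});
    return desc.strides.back() == 1 ? 0 : 1;
}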
@@ -130,10 +130,10 @@ bool run_batched_gemm_gemm_example(int argc, char* argv[])
    Tensor<CDataType> c_g_m_o_device_result(
        f_host_tensor_descriptor(BatchCount, M, O, StrideC, BatchStrideC, CLayout{}));

    std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
    std::cout << "b0_g_k_n: " << b0_g_k_n.mDesc << std::endl;
    std::cout << "b1_g_n_o: " << b1_g_n_o.mDesc << std::endl;
    std::cout << "c_g_m_o: " << c_g_m_o_host_result.mDesc << std::endl;
    std::cout << "a_g_m_k: " << a_g_m_k.GetDesc() << std::endl;
    std::cout << "b0_g_k_n: " << b0_g_k_n.GetDesc() << std::endl;
    std::cout << "b1_g_n_o: " << b1_g_n_o.GetDesc() << std::endl;
    std::cout << "c_g_m_o: " << c_g_m_o_host_result.GetDesc() << std::endl;

    switch(init_method)
    {
...
...
@@ -155,29 +155,27 @@ bool run_batched_gemm_gemm_example(int argc, char* argv[])
}
#ifdef BUILD_INT4_EXAMPLE
    DeviceMem a_g_m_k_device_buf(sizeof(KernelADataType) * a_g_m_k.mDesc.GetElementSpaceSize());
    DeviceMem b0_g_k_n_device_buf(sizeof(KernelB0DataType) * b0_g_k_n.mDesc.GetElementSpaceSize());
    DeviceMem b1_g_n_o_device_buf(sizeof(KernelB1DataType) * b1_g_n_o.mDesc.GetElementSpaceSize());
    DeviceMem c_g_m_o_device_buf(sizeof(KernelCDataType) * c_g_m_o_device_result.mDesc.GetElementSpaceSize());
    DeviceMem a_g_m_k_device_buf(a_g_m_k.GetMemorySize());
    DeviceMem b0_g_k_n_device_buf(b0_g_k_n.GetMemorySize());
    DeviceMem b1_g_n_o_device_buf(b1_g_n_o.GetMemorySize());
    DeviceMem c_g_m_o_device_buf(c_g_m_o_device_result.GetMemorySize());

    const Tensor<KernelADataType> a_g_m_k_converted(a_g_m_k);
    const Tensor<KernelB0DataType> b0_g_k_n_converted(b0_g_k_n);
    const Tensor<KernelB1DataType> b1_g_n_o_converted(b1_g_n_o);

    a_g_m_k_device_buf.ToDevice(a_g_m_k_converted.mData.data());
    b0_g_k_n_device_buf.ToDevice(b0_g_k_n_converted.mData.data());
    b1_g_n_o_device_buf.ToDevice(b1_g_n_o_converted.mData.data());
    a_g_m_k_device_buf.ToDevice(a_g_m_k_converted.data());
    b0_g_k_n_device_buf.ToDevice(b0_g_k_n_converted.data());
    b1_g_n_o_device_buf.ToDevice(b1_g_n_o_converted.data());
#else
    DeviceMem a_g_m_k_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize());
    DeviceMem b0_g_k_n_device_buf(sizeof(B0DataType) * b0_g_k_n.mDesc.GetElementSpaceSize());
    DeviceMem b1_g_n_o_device_buf(sizeof(B1DataType) * b1_g_n_o.mDesc.GetElementSpaceSize());
    DeviceMem c_g_m_o_device_buf(sizeof(CDataType) * c_g_m_o_device_result.mDesc.GetElementSpaceSize());

    a_g_m_k_device_buf.ToDevice(a_g_m_k.mData.data());
    b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data());
    b1_g_n_o_device_buf.ToDevice(b1_g_n_o.mData.data());
    DeviceMem a_g_m_k_device_buf(a_g_m_k.GetMemorySize());
    DeviceMem b0_g_k_n_device_buf(b0_g_k_n.GetMemorySize());
    DeviceMem b1_g_n_o_device_buf(b1_g_n_o.GetMemorySize());
    DeviceMem c_g_m_o_device_buf(c_g_m_o_device_result.GetMemorySize());

    a_g_m_k_device_buf.ToDevice(a_g_m_k.data());
    b0_g_k_n_device_buf.ToDevice(b0_g_k_n.data());
    b1_g_n_o_device_buf.ToDevice(b1_g_n_o.data());
#endif

    auto a_element_op = AElementOp{};
...
...
@@ -189,36 +187,28 @@ bool run_batched_gemm_gemm_example(int argc, char* argv[])
// do GEMM
    auto gemm     = DeviceGemmInstance{};
    auto invoker  = gemm.MakeInvoker();
    auto argument = gemm.MakeArgument(
#ifdef BUILD_INT4_EXAMPLE
        static_cast<KernelADataType*>(a_g_m_k_device_buf.GetDeviceBuffer()),
        static_cast<KernelB0DataType*>(b0_g_k_n_device_buf.GetDeviceBuffer()),
        static_cast<KernelB1DataType*>(b1_g_n_o_device_buf.GetDeviceBuffer()),
        static_cast<KernelCDataType*>(c_g_m_o_device_buf.GetDeviceBuffer()),
#else
        static_cast<ADataType*>(a_g_m_k_device_buf.GetDeviceBuffer()),
        static_cast<B0DataType*>(b0_g_k_n_device_buf.GetDeviceBuffer()),
        static_cast<B1DataType*>(b1_g_n_o_device_buf.GetDeviceBuffer()),
        static_cast<CDataType*>(c_g_m_o_device_buf.GetDeviceBuffer()),
#endif
        M, N, K, O, BatchCount,
        StrideA, StrideB0, StrideB1, StrideC,
        BatchStrideA, BatchStrideB0, BatchStrideB1, BatchStrideC,
        a_element_op, b0_element_op, acc0_element_op, b1_element_op, c_element_op);
    auto argument = gemm.MakeArgument(a_g_m_k_device_buf.GetDeviceBuffer(),
                                      b0_g_k_n_device_buf.GetDeviceBuffer(),
                                      b1_g_n_o_device_buf.GetDeviceBuffer(),
                                      c_g_m_o_device_buf.GetDeviceBuffer(),
                                      M, N, K, O, BatchCount,
                                      StrideA, StrideB0, StrideB1, StrideC,
                                      BatchStrideA, BatchStrideB0, BatchStrideB1, BatchStrideC,
                                      a_element_op, b0_element_op, acc0_element_op, b1_element_op, c_element_op);

    if(!gemm.IsSupportedArgument(argument))
    {
...
...
@@ -261,16 +251,16 @@ bool run_batched_gemm_gemm_example(int argc, char* argv[])
        ref_gemm1_invoker.Run(ref_gemm1_argument);

#ifdef BUILD_INT4_EXAMPLE
        Tensor<KernelCDataType> c_g_m_o_device_result_converted(c_g_m_o_host_result.mDesc);
        Tensor<KernelCDataType> c_g_m_o_device_result_converted(c_g_m_o_host_result.GetDesc());

        c_g_m_o_device_buf.FromDevice(c_g_m_o_device_result_converted.mData.data());
        c_g_m_o_device_buf.FromDevice(c_g_m_o_device_result_converted.data());

        c_g_m_o_device_result = c_g_m_o_device_result_converted.CopyAsType<CDataType>();
#else
        c_g_m_o_device_buf.FromDevice(c_g_m_o_device_result.mData.data());
        c_g_m_o_device_buf.FromDevice(c_g_m_o_device_result.data());
#endif

        return ck::utils::check_err(c_g_m_o_device_result.mData, c_g_m_o_host_result.mData);
        return ck::utils::check_err(c_g_m_o_device_result, c_g_m_o_host_result);
    }

    return true;
...
...
example/32_batched_gemm_scale_softmax_gemm/batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp
View file @
e4e99a49
...
...
@@ -9,23 +9,24 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
Gemm1
*/
#include <cstdlib>
#include <initializer_list>
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
#include "ck/library/utility/literals.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
...
...
@@ -222,7 +223,9 @@ int main(int argc, char* argv[])
    BatchStrideB0 = BatchStrideB0 < 0 ? DefaultBatchStrideB0 : BatchStrideB0;
    BatchStrideB1 = BatchStrideB1 < 0 ? DefaultBatchStrideB1 : BatchStrideB1;

    const int BatchCount = G0 * G1;
    const ck::index_t BatchCount = G0 * G1;

    using namespace ck::literals;

    auto f_host_tensor_descriptor = [](std::size_t batch_count,
                                       std::size_t row,
...
...
@@ -230,15 +233,13 @@ int main(int argc, char* argv[])
                                       std::size_t stride,
                                       std::size_t batch_stride,
                                       auto layout) {
        if(std::is_same<decltype(layout), Row>::value)
        if constexpr(std::is_same_v<decltype(layout), Row>)
        {
            return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
                                        std::vector<std::size_t>({batch_stride, stride, 1}));
            return HostTensorDescriptor({batch_count, row, col}, {batch_stride, stride, 1_uz});
        }
        else
        {
            return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
                                        std::vector<std::size_t>({batch_stride, 1, stride}));
            return HostTensorDescriptor({batch_count, row, col}, {batch_stride, 1_uz, stride});
        }
    };
...
...
@@ -249,17 +250,13 @@ int main(int argc, char* argv[])
        f_host_tensor_descriptor(BatchCount, K, N, StrideB0, BatchStrideB0, B0Layout{}));
    Tensor<B1DataType> b1_g_n_o(
        f_host_tensor_descriptor(BatchCount, N, O, StrideB1, BatchStrideB1, B1Layout{}));
    Tensor<CDataType> c_gs_ms_os_host_result(
        std::vector<std::size_t>(c_gs_ms_os_lengths.begin(), c_gs_ms_os_lengths.end()),
        std::vector<std::size_t>(c_gs_ms_os_strides.begin(), c_gs_ms_os_strides.end()));
    Tensor<CDataType> c_gs_ms_os_device_result(
        std::vector<std::size_t>(c_gs_ms_os_lengths.begin(), c_gs_ms_os_lengths.end()),
        std::vector<std::size_t>(c_gs_ms_os_strides.begin(), c_gs_ms_os_strides.end()));

    std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
    std::cout << "b0_g_k_n: " << b0_g_k_n.mDesc << std::endl;
    std::cout << "b1_g_n_o: " << b1_g_n_o.mDesc << std::endl;
    std::cout << "c_gs_ms_os: " << c_gs_ms_os_host_result.mDesc << std::endl;
    Tensor<CDataType> c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
    Tensor<CDataType> c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);

    std::cout << "a_g_m_k: " << a_g_m_k.GetDesc() << std::endl;
    std::cout << "b0_g_k_n: " << b0_g_k_n.GetDesc() << std::endl;
    std::cout << "b1_g_n_o: " << b1_g_n_o.GetDesc() << std::endl;
    std::cout << "c_gs_ms_os: " << c_gs_ms_os_host_result.GetDesc() << std::endl;

    switch(init_method)
    {
...
...
@@ -285,15 +282,14 @@ int main(int argc, char* argv[])
        b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
    }

    DeviceMem a_g_m_k_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize());
    DeviceMem b0_g_k_n_device_buf(sizeof(B0DataType) * b0_g_k_n.mDesc.GetElementSpaceSize());
    DeviceMem b1_g_n_o_device_buf(sizeof(B1DataType) * b1_g_n_o.mDesc.GetElementSpaceSize());
    DeviceMem c_gs_ms_os_device_buf(sizeof(CDataType) * c_gs_ms_os_device_result.mDesc.GetElementSpaceSize());
    DeviceMem a_g_m_k_device_buf(a_g_m_k.GetMemorySize());
    DeviceMem b0_g_k_n_device_buf(b0_g_k_n.GetMemorySize());
    DeviceMem b1_g_n_o_device_buf(b1_g_n_o.GetMemorySize());
    DeviceMem c_gs_ms_os_device_buf(c_gs_ms_os_device_result.GetMemorySize());

    a_g_m_k_device_buf.ToDevice(a_g_m_k.mData.data());
    b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data());
    b1_g_n_o_device_buf.ToDevice(b1_g_n_o.mData.data());
    a_g_m_k_device_buf.ToDevice(a_g_m_k.data());
    b0_g_k_n_device_buf.ToDevice(b0_g_k_n.data());
    b1_g_n_o_device_buf.ToDevice(b1_g_n_o.data());

    auto a_element_op  = AElementOp{};
    auto b0_element_op = B0ElementOp{};
...
...
@@ -302,31 +298,30 @@ int main(int argc, char* argv[])
    auto c_element_op = CElementOp{};

    // do GEMM
    auto gemm     = DeviceGemmInstance{};
    auto invoker  = gemm.MakeInvoker();
    auto argument = gemm.MakeArgument(
        static_cast<ADataType*>(a_g_m_k_device_buf.GetDeviceBuffer()),
        static_cast<B0DataType*>(b0_g_k_n_device_buf.GetDeviceBuffer()),
        static_cast<B1DataType*>(b1_g_n_o_device_buf.GetDeviceBuffer()),
        static_cast<CDataType*>(c_gs_ms_os_device_buf.GetDeviceBuffer()),
        M, N, K, O, BatchCount,
        c_gs_ms_os_lengths, c_gs_ms_os_strides,
        StrideA, StrideB0, StrideB1,
        BatchStrideA, BatchStrideB0, BatchStrideB1,
        a_element_op, b0_element_op, acc0_element_op, b1_element_op, c_element_op);
    auto gemm     = DeviceGemmInstance{};
    auto invoker  = gemm.MakeInvoker();
    auto argument = gemm.MakeArgument(a_g_m_k_device_buf.GetDeviceBuffer(),
                                      b0_g_k_n_device_buf.GetDeviceBuffer(),
                                      b1_g_n_o_device_buf.GetDeviceBuffer(),
                                      c_gs_ms_os_device_buf.GetDeviceBuffer(),
                                      M, N, K, O, BatchCount,
                                      c_gs_ms_os_lengths, c_gs_ms_os_strides,
                                      StrideA, StrideB0, StrideB1,
                                      BatchStrideA, BatchStrideB0, BatchStrideB1,
                                      a_element_op, b0_element_op, acc0_element_op, b1_element_op, c_element_op);

    if(!gemm.IsSupportedArgument(argument))
    {
...
...
@@ -351,15 +346,14 @@ int main(int argc, char* argv[])
    if(do_verification)
    {
        c_gs_ms_os_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data());
        c_gs_ms_os_device_buf.FromDevice(c_gs_ms_os_device_result.data());

        // Output of Gemm0 is input A of Gemm1
        Tensor<AccDataType> acc0_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
        Tensor<ADataType> a1_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
        Tensor<CDataType> c_g_m_o_host_result(std::vector<int>{BatchCount, M, O}, std::vector<int>{M * O, O, 1});
        Tensor<CDataType> c_g_m_o_host_result({BatchCount, M, O}, {M * O, O, 1});

        auto ref_gemm0         = ReferenceGemm0Instance{};
        auto ref_gemm0_invoker = ref_gemm0.MakeInvoker();
...
@@ -400,9 +394,7 @@ int main(int argc, char* argv[])
            self(idx) = c_g_m_o_host_result(g, idx[2], idx[3]);
        });

        return ck::utils::check_err(c_gs_ms_os_device_result.mData, c_gs_ms_os_host_result.mData) ? 0 : 1;
        return ck::utils::check_err(c_gs_ms_os_device_result, c_gs_ms_os_host_result) ? 0 : 1;
    }

    return 0;
...
...
example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp
View file @
e4e99a49
...
...
@@ -9,23 +9,24 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
Gemm1
*/
#include <cstdlib>
#include <initializer_list>
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
#include "ck/library/utility/literals.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
...
...
@@ -222,7 +223,9 @@ int main(int argc, char* argv[])
    BatchStrideB0 = BatchStrideB0 < 0 ? DefaultBatchStrideB0 : BatchStrideB0;
    BatchStrideB1 = BatchStrideB1 < 0 ? DefaultBatchStrideB1 : BatchStrideB1;

    const int BatchCount = G0 * G1;
    const ck::index_t BatchCount = G0 * G1;

    using namespace ck::literals;

    auto f_host_tensor_descriptor = [](std::size_t batch_count,
                                       std::size_t row,
...
...
@@ -230,15 +233,13 @@ int main(int argc, char* argv[])
                                       std::size_t stride,
                                       std::size_t batch_stride,
                                       auto layout) {
        if(std::is_same<decltype(layout), Row>::value)
        if constexpr(std::is_same_v<decltype(layout), Row>)
        {
            return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
                                        std::vector<std::size_t>({batch_stride, stride, 1}));
            return HostTensorDescriptor({batch_count, row, col}, {batch_stride, stride, 1_uz});
        }
        else
        {
            return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
                                        std::vector<std::size_t>({batch_stride, 1, stride}));
            return HostTensorDescriptor({batch_count, row, col}, {batch_stride, 1_uz, stride});
        }
    };
...
...
@@ -249,17 +250,13 @@ int main(int argc, char* argv[])
        f_host_tensor_descriptor(BatchCount, K, N, StrideB0, BatchStrideB0, B0Layout{}));
    Tensor<B1DataType> b1_g_n_o(
        f_host_tensor_descriptor(BatchCount, N, O, StrideB1, BatchStrideB1, B1Layout{}));
    Tensor<CDataType> c_gs_ms_os_host_result(
        std::vector<std::size_t>(c_gs_ms_os_lengths.begin(), c_gs_ms_os_lengths.end()),
        std::vector<std::size_t>(c_gs_ms_os_strides.begin(), c_gs_ms_os_strides.end()));
    Tensor<CDataType> c_gs_ms_os_device_result(
        std::vector<std::size_t>(c_gs_ms_os_lengths.begin(), c_gs_ms_os_lengths.end()),
        std::vector<std::size_t>(c_gs_ms_os_strides.begin(), c_gs_ms_os_strides.end()));

    std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
    std::cout << "b0_g_k_n: " << b0_g_k_n.mDesc << std::endl;
    std::cout << "b1_g_n_o: " << b1_g_n_o.mDesc << std::endl;
    std::cout << "c_gs_ms_os: " << c_gs_ms_os_host_result.mDesc << std::endl;
    Tensor<CDataType> c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
    Tensor<CDataType> c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);

    std::cout << "a_g_m_k: " << a_g_m_k.GetDesc() << std::endl;
    std::cout << "b0_g_k_n: " << b0_g_k_n.GetDesc() << std::endl;
    std::cout << "b1_g_n_o: " << b1_g_n_o.GetDesc() << std::endl;
    std::cout << "c_gs_ms_os: " << c_gs_ms_os_host_result.GetDesc() << std::endl;

    switch(init_method)
    {
...
...
@@ -285,15 +282,14 @@ int main(int argc, char* argv[])
        b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
    }

    DeviceMem a_g_m_k_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize());
    DeviceMem b0_g_k_n_device_buf(sizeof(B0DataType) * b0_g_k_n.mDesc.GetElementSpaceSize());
    DeviceMem b1_g_n_o_device_buf(sizeof(B1DataType) * b1_g_n_o.mDesc.GetElementSpaceSize());
    DeviceMem c_gs_ms_os_device_buf(sizeof(CDataType) * c_gs_ms_os_device_result.mDesc.GetElementSpaceSize());
    DeviceMem a_g_m_k_device_buf(a_g_m_k.GetMemorySize());
    DeviceMem b0_g_k_n_device_buf(b0_g_k_n.GetMemorySize());
    DeviceMem b1_g_n_o_device_buf(b1_g_n_o.GetMemorySize());
    DeviceMem c_gs_ms_os_device_buf(c_gs_ms_os_device_result.GetMemorySize());

    a_g_m_k_device_buf.ToDevice(a_g_m_k.mData.data());
    b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data());
    b1_g_n_o_device_buf.ToDevice(b1_g_n_o.mData.data());
    a_g_m_k_device_buf.ToDevice(a_g_m_k.data());
    b0_g_k_n_device_buf.ToDevice(b0_g_k_n.data());
    b1_g_n_o_device_buf.ToDevice(b1_g_n_o.data());

    auto a_element_op  = AElementOp{};
    auto b0_element_op = B0ElementOp{};
...
...
@@ -302,31 +298,30 @@ int main(int argc, char* argv[])
    auto c_element_op = CElementOp{};

    // do GEMM
    auto gemm     = DeviceGemmInstance{};
    auto invoker  = gemm.MakeInvoker();
    auto argument = gemm.MakeArgument(
        static_cast<ADataType*>(a_g_m_k_device_buf.GetDeviceBuffer()),
        static_cast<B0DataType*>(b0_g_k_n_device_buf.GetDeviceBuffer()),
        static_cast<B1DataType*>(b1_g_n_o_device_buf.GetDeviceBuffer()),
        static_cast<CDataType*>(c_gs_ms_os_device_buf.GetDeviceBuffer()),
        M, N, K, O, BatchCount,
        c_gs_ms_os_lengths, c_gs_ms_os_strides,
        StrideA, StrideB0, StrideB1,
        BatchStrideA, BatchStrideB0, BatchStrideB1,
        a_element_op, b0_element_op, acc0_element_op, b1_element_op, c_element_op);
    auto gemm     = DeviceGemmInstance{};
    auto invoker  = gemm.MakeInvoker();
    auto argument = gemm.MakeArgument(a_g_m_k_device_buf.GetDeviceBuffer(),
                                      b0_g_k_n_device_buf.GetDeviceBuffer(),
                                      b1_g_n_o_device_buf.GetDeviceBuffer(),
                                      c_gs_ms_os_device_buf.GetDeviceBuffer(),
                                      M, N, K, O, BatchCount,
                                      c_gs_ms_os_lengths, c_gs_ms_os_strides,
                                      StrideA, StrideB0, StrideB1,
                                      BatchStrideA, BatchStrideB0, BatchStrideB1,
                                      a_element_op, b0_element_op, acc0_element_op, b1_element_op, c_element_op);

    if(!gemm.IsSupportedArgument(argument))
    {
...
...
@@ -351,15 +346,14 @@ int main(int argc, char* argv[])
    if(do_verification)
    {
        c_gs_ms_os_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data());
        c_gs_ms_os_device_buf.FromDevice(c_gs_ms_os_device_result.data());

        // Output of Gemm0 is input A of Gemm1
        Tensor<AccDataType> acc0_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
        Tensor<ADataType> a1_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
        Tensor<CDataType> c_g_m_o_host_result(std::vector<int>{BatchCount, M, O}, std::vector<int>{M * O, O, 1});
        Tensor<CDataType> c_g_m_o_host_result({BatchCount, M, O}, {M * O, O, 1});

        auto ref_gemm0         = ReferenceGemm0Instance{};
        auto ref_gemm0_invoker = ref_gemm0.MakeInvoker();
...
...
@@ -390,9 +384,7 @@ int main(int argc, char* argv[])
            self(idx) = c_g_m_o_host_result(g, idx[2], idx[3]);
        });

        return ck::utils::check_err(c_gs_ms_os_device_result.mData, c_gs_ms_os_host_result.mData) ? 0 : 1;
        return ck::utils::check_err(c_gs_ms_os_device_result, c_gs_ms_os_host_result) ? 0 : 1;
    }

    return 0;
...
...
example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_xdl_fp16.cpp
View file @
e4e99a49
...
...
@@ -9,22 +9,23 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
Gemm1
*/
#include <cstdlib>
#include <initializer_list>
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
#include "ck/library/utility/literals.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
...
...
@@ -239,21 +240,21 @@ int main(int argc, char* argv[])
    BatchStrideB1 = BatchStrideB1 < 0 ? DefaultBatchStrideB1 : BatchStrideB1;
    BatchStrideC  = BatchStrideC < 0 ? DefaultBatchStrideC : BatchStrideC;

    using namespace ck::literals;

    auto f_host_tensor_descriptor = [](std::size_t batch_count,
                                       std::size_t row,
                                       std::size_t col,
                                       std::size_t stride,
                                       std::size_t batch_stride,
                                       auto layout) {
        if(std::is_same<decltype(layout), Row>::value)
        if constexpr(std::is_same_v<decltype(layout), Row>)
        {
            return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
                                        std::vector<std::size_t>({batch_stride, stride, 1}));
            return HostTensorDescriptor({batch_count, row, col}, {batch_stride, stride, 1_uz});
        }
        else
        {
            return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
                                        std::vector<std::size_t>({batch_stride, 1, stride}));
            return HostTensorDescriptor({batch_count, row, col}, {batch_stride, 1_uz, stride});
        }
    };
...
...
@@ -269,10 +270,10 @@ int main(int argc, char* argv[])
    Tensor<CDataType> c_g_m_o_device_result(
        f_host_tensor_descriptor(BatchCount, M, O, StrideC, BatchStrideC, CLayout{}));

    std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
    std::cout << "b0_g_k_n: " << b0_g_k_n.mDesc << std::endl;
    std::cout << "b1_g_n_o: " << b1_g_n_o.mDesc << std::endl;
    std::cout << "c_g_m_o: " << c_g_m_o_host_result.mDesc << std::endl;
    std::cout << "a_g_m_k: " << a_g_m_k.GetDesc() << std::endl;
    std::cout << "b0_g_k_n: " << b0_g_k_n.GetDesc() << std::endl;
    std::cout << "b1_g_n_o: " << b1_g_n_o.GetDesc() << std::endl;
    std::cout << "c_g_m_o: " << c_g_m_o_host_result.GetDesc() << std::endl;

    switch(init_method)
    {
...
...
@@ -298,15 +299,14 @@ int main(int argc, char* argv[])
        b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
    }

    DeviceMem a_g_m_k_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize());
    DeviceMem b0_g_k_n_device_buf(sizeof(B0DataType) * b0_g_k_n.mDesc.GetElementSpaceSize());
    DeviceMem b1_g_n_o_device_buf(sizeof(B1DataType) * b1_g_n_o.mDesc.GetElementSpaceSize());
    DeviceMem c_g_m_o_device_buf(sizeof(CDataType) * c_g_m_o_device_result.mDesc.GetElementSpaceSize());
    DeviceMem a_g_m_k_device_buf(a_g_m_k.GetMemorySize());
    DeviceMem b0_g_k_n_device_buf(b0_g_k_n.GetMemorySize());
    DeviceMem b1_g_n_o_device_buf(b1_g_n_o.GetMemorySize());
    DeviceMem c_g_m_o_device_buf(c_g_m_o_device_result.GetMemorySize());

    a_g_m_k_device_buf.ToDevice(a_g_m_k.mData.data());
    b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data());
    b1_g_n_o_device_buf.ToDevice(b1_g_n_o.mData.data());
    a_g_m_k_device_buf.ToDevice(a_g_m_k.data());
    b0_g_k_n_device_buf.ToDevice(b0_g_k_n.data());
    b1_g_n_o_device_buf.ToDevice(b1_g_n_o.data());

    auto a_element_op  = AElementOp{};
    auto b0_element_op = B0ElementOp{};
...
...
@@ -315,31 +315,30 @@ int main(int argc, char* argv[])
    auto c_element_op = CElementOp{};

    // do GEMM
    auto gemm     = DeviceGemmInstance{};
    auto invoker  = gemm.MakeInvoker();
    auto argument = gemm.MakeArgument(
        static_cast<ADataType*>(a_g_m_k_device_buf.GetDeviceBuffer()),
        static_cast<B0DataType*>(b0_g_k_n_device_buf.GetDeviceBuffer()),
        static_cast<B1DataType*>(b1_g_n_o_device_buf.GetDeviceBuffer()),
        static_cast<CDataType*>(c_g_m_o_device_buf.GetDeviceBuffer()),
        M, N, K, O, BatchCount,
        StrideA, StrideB0, StrideB1, StrideC,
        BatchStrideA, BatchStrideB0, BatchStrideB1, BatchStrideC,
        a_element_op, b0_element_op, acc0_element_op, b1_element_op, c_element_op);
    auto gemm     = DeviceGemmInstance{};
    auto invoker  = gemm.MakeInvoker();
    auto argument = gemm.MakeArgument(a_g_m_k_device_buf.GetDeviceBuffer(),
                                      b0_g_k_n_device_buf.GetDeviceBuffer(),
                                      b1_g_n_o_device_buf.GetDeviceBuffer(),
                                      c_g_m_o_device_buf.GetDeviceBuffer(),
                                      M, N, K, O, BatchCount,
                                      StrideA, StrideB0, StrideB1, StrideC,
                                      BatchStrideA, BatchStrideB0, BatchStrideB1, BatchStrideC,
                                      a_element_op, b0_element_op, acc0_element_op, b1_element_op, c_element_op);

    if(!gemm.IsSupportedArgument(argument))
    {
...
...
@@ -362,7 +361,7 @@ int main(int argc, char* argv[])
    std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
              << gemm.GetTypeString() << std::endl;

    c_g_m_o_device_buf.FromDevice(c_g_m_o_device_result.mData.data());
    c_g_m_o_device_buf.FromDevice(c_g_m_o_device_result.data());

    if(do_verification)
    {
...
...
@@ -391,7 +390,7 @@ int main(int argc, char* argv[])
        ref_gemm1_invoker.Run(ref_gemm1_argument);

        return ck::utils::check_err(c_g_m_o_device_result.mData, c_g_m_o_host_result.mData) ? 0 : 1;
        return ck::utils::check_err(c_g_m_o_device_result, c_g_m_o_host_result) ? 0 : 1;
    }

    return 0;
...
...