Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
c6891e12
Commit
c6891e12
authored
Jul 01, 2022
by
rocking
Browse files
Merge branch 'develop' into standalone-layernorm
parents
f591ad27
8e374781
Changes
296
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
631 additions
and
532 deletions
+631
-532
profiler/include/profile_gemm_reduce_impl.hpp
profiler/include/profile_gemm_reduce_impl.hpp
+7
-8
profiler/include/profile_gemm_splitk_impl.hpp
profiler/include/profile_gemm_splitk_impl.hpp
+16
-15
profiler/include/profile_grouped_gemm_impl.hpp
profiler/include/profile_grouped_gemm_impl.hpp
+11
-9
profiler/include/profile_normalization_impl.hpp
profiler/include/profile_normalization_impl.hpp
+243
-0
profiler/include/profile_reduce_impl.hpp
profiler/include/profile_reduce_impl.hpp
+4
-4
profiler/src/profile_batched_gemm.cpp
profiler/src/profile_batched_gemm.cpp
+14
-3
profiler/src/profile_gemm_add_add_fastgelu.cpp
profiler/src/profile_gemm_add_add_fastgelu.cpp
+10
-16
profiler/src/profile_normalization.cpp
profiler/src/profile_normalization.cpp
+134
-0
profiler/src/profiler.cpp
profiler/src/profiler.cpp
+6
-0
script/docker-rocm4.1.sh
script/docker-rocm4.1.sh
+0
-14
script/docker-rocm4.3.1.sh
script/docker-rocm4.3.1.sh
+0
-14
test/batched_gemm/batched_gemm_fp16.cpp
test/batched_gemm/batched_gemm_fp16.cpp
+4
-4
test/conv2d_bwd_data/conv2d_bwd_data.cpp
test/conv2d_bwd_data/conv2d_bwd_data.cpp
+6
-6
test/convnd_fwd/conv_util.hpp
test/convnd_fwd/conv_util.hpp
+6
-6
test/gemm/CMakeLists.txt
test/gemm/CMakeLists.txt
+12
-26
test/gemm/gemm_bf16.cpp
test/gemm/gemm_bf16.cpp
+79
-0
test/gemm/gemm_dl_fp16.cpp
test/gemm/gemm_dl_fp16.cpp
+0
-137
test/gemm/gemm_dl_fp32.cpp
test/gemm/gemm_dl_fp32.cpp
+0
-135
test/gemm/gemm_dl_int8.cpp
test/gemm/gemm_dl_int8.cpp
+0
-135
test/gemm/gemm_fp16.cpp
test/gemm/gemm_fp16.cpp
+79
-0
No files found.
profiler/include/profile_gemm_reduce_impl.hpp
View file @
c6891e12
...
...
@@ -19,7 +19,7 @@
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_gemm_
instance
{
namespace
instance
{
using
F32
=
float
;
using
F16
=
ck
::
half_t
;
...
...
@@ -45,7 +45,7 @@ void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances(
void
add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances
(
std
::
vector
<
DeviceGemmReduceNoOpPtr
>&
);
}
// namespace
device_gemm_
instance
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
...
...
@@ -204,8 +204,7 @@ bool profile_gemm_reduce_impl(int do_verification,
b_device_buf
.
ToDevice
(
b_k_n
.
mData
.
data
());
// add device GEMM instances
std
::
vector
<
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
DeviceGemmReduceNoOpPtr
>
gemm_ptrs
;
std
::
vector
<
ck
::
tensor_operation
::
device
::
instance
::
DeviceGemmReduceNoOpPtr
>
gemm_ptrs
;
if
constexpr
(
is_same
<
ADataType
,
half_t
>::
value
&&
is_same
<
BDataType
,
half_t
>::
value
&&
is_same
<
CDataType
,
half_t
>::
value
)
...
...
@@ -214,7 +213,7 @@ bool profile_gemm_reduce_impl(int do_verification,
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_gemm_
instance
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances
(
gemm_ptrs
);
}
...
...
@@ -222,7 +221,7 @@ bool profile_gemm_reduce_impl(int do_verification,
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_gemm_
instance
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instances
(
gemm_ptrs
);
}
...
...
@@ -230,7 +229,7 @@ bool profile_gemm_reduce_impl(int do_verification,
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_gemm_
instance
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances
(
gemm_ptrs
);
}
...
...
@@ -238,7 +237,7 @@ bool profile_gemm_reduce_impl(int do_verification,
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_gemm_
instance
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances
(
gemm_ptrs
);
}
...
...
profiler/include/profile_gemm_splitk_impl.hpp
View file @
c6891e12
...
...
@@ -12,7 +12,7 @@
#include "ck/tensor_operation/gpu/device/device_gemm_splitk.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/
device_
gemm_splitk
_instance
.hpp"
#include "ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
...
...
@@ -95,20 +95,21 @@ bool profile_gemm_splitk_impl(int do_verification,
b_device_buf
.
ToDevice
(
b_k_n
.
mData
.
data
());
c_device_buf
.
ToDevice
(
c_m_n_device_result
.
mData
.
data
());
// add device op instances
const
auto
op_ptrs
=
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
get_device_gemm_splitk_instances
<
ADataType
,
BDataType
,
CDataType
,
ALayout
,
BLayout
,
CLayout
>
();
if
(
op_ptrs
.
size
()
<=
0
)
{
throw
std
::
runtime_error
(
"wrong! no device operation instance found"
);
}
using
DeviceOp
=
ck
::
tensor_operation
::
device
::
DeviceGemmSplitK
<
ALayout
,
BLayout
,
CLayout
,
ADataType
,
BDataType
,
CDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
// get device op instances
const
auto
op_ptrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
DeviceOp
>::
GetInstances
();
std
::
cout
<<
"found "
<<
op_ptrs
.
size
()
<<
" instances"
<<
std
::
endl
;
// Run reference GEMM
if
(
do_verification
)
...
...
profiler/include/profile_grouped_gemm_impl.hpp
View file @
c6891e12
...
...
@@ -20,7 +20,7 @@
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_grouped_gemm_
instance
{
namespace
instance
{
using
DeviceGroupedGemmNoOpPtr
=
ck
::
tensor_operation
::
device
::
DeviceGroupedGemmPtr
<
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
...
...
@@ -36,7 +36,7 @@ void add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(
void
add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances
(
std
::
vector
<
DeviceGroupedGemmNoOpPtr
>&
);
}
// namespace
device_grouped_gemm_
instance
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
...
...
@@ -171,9 +171,7 @@ void profile_grouped_gemm_impl(int do_verification,
}
// add device GEMM instances
std
::
vector
<
ck
::
tensor_operation
::
device
::
device_grouped_gemm_instance
::
DeviceGroupedGemmNoOpPtr
>
gemm_ptrs
;
std
::
vector
<
ck
::
tensor_operation
::
device
::
instance
::
DeviceGroupedGemmNoOpPtr
>
gemm_ptrs
;
if
constexpr
(
is_same
<
ADataType
,
half_t
>::
value
&&
is_same
<
BDataType
,
half_t
>::
value
&&
is_same
<
CDataType
,
half_t
>::
value
)
...
...
@@ -182,28 +180,28 @@ void profile_grouped_gemm_impl(int do_verification,
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_grouped_gemm_
instance
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances
(
gemm_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_grouped_gemm_
instance
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances
(
gemm_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_grouped_gemm_
instance
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances
(
gemm_ptrs
);
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
ck
::
tensor_operation
::
device
::
device_grouped_gemm_
instance
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances
(
gemm_ptrs
);
}
}
...
...
@@ -232,6 +230,10 @@ void profile_grouped_gemm_impl(int do_verification,
auto
invoker_ptr
=
gemm_ptr
->
MakeInvokerPointer
();
DeviceMem
gemm_desc_workspace
(
gemm_ptr
->
GetWorkSpaceSize
(
argument_ptr
.
get
()));
gemm_ptr
->
SetWorkSpacePointer
(
argument_ptr
.
get
(),
gemm_desc_workspace
.
GetDeviceBuffer
());
if
(
gemm_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
std
::
string
gemm_name
=
gemm_ptr
->
GetTypeString
();
...
...
profiler/include/profile_normalization_impl.hpp
0 → 100644
View file @
c6891e12
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iomanip>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_softmax.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/conv_util.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/host_tensor/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
// Forward declarations of the softmax instance registration functions.
// Each one takes the caller's vector of DeviceNormalizationPtr by reference;
// the definitions live outside this header. One function exists per
// (data type, tensor rank) combination.
namespace ck::tensor_operation::device::instance {

// fp16 input / fp16 output
void add_device_softmax_f16_f16_rank3_instances(std::vector<DeviceNormalizationPtr>&);
void add_device_softmax_f16_f16_rank4_instances(std::vector<DeviceNormalizationPtr>&);

// fp32 input / fp32 output
void add_device_softmax_f32_f32_rank3_instances(std::vector<DeviceNormalizationPtr>&);
void add_device_softmax_f32_f32_rank4_instances(std::vector<DeviceNormalizationPtr>&);

} // namespace ck::tensor_operation::device::instance
namespace
ck
{
namespace
profiler
{
/// Normalization operator selected by the profiler front end.
/// Enumerator values are the implicit 0, 1, 2 of the declaration order,
/// spelled out here for readability.
enum class NormType
{
    LAYERNORM = 0,
    BATCHNORM = 1,
    SOFTMAX   = 2,
};
/// Input/output element type pairing of a normalization problem.
/// Each enumerator is named <in>_<out>; the numeric values follow the
/// declaration order (0..3).
enum class NormDataType
{
    F32_F32   = 0, // in, out
    F16_F16   = 1,
    BF16_BF16 = 2,
    INT8_INT8 = 3,
};
// clang-format off
/// Returns the short textual name of element type T used in the profiler's
/// summary line (e.g. "f32", "f16").
///
/// Only the explicit specializations below are defined; the primary template
/// is declared but never defined, so using an unlisted type fails at link time.
///
/// The specializations are marked `inline`: an explicit (full) specialization
/// of a function template is an ordinary function definition, and without
/// `inline` every translation unit including this header would emit its own
/// definition, violating the one-definition rule.
template <typename T> std::string type_to_string();
template <> inline std::string type_to_string<float>()   { return "f32"; }
template <> inline std::string type_to_string<half_t>()  { return "f16"; }
template <> inline std::string type_to_string<bhalf_t>() { return "bf16"; }
template <> inline std::string type_to_string<int8_t>()  { return "int8"; }
template <> inline std::string type_to_string<int32_t>() { return "int32"; }
// clang-format on
/// Runs every registered device normalization instance that matches the
/// problem, prints per-instance timing/bandwidth and finally the best result.
///
/// \tparam InDataType   element type of the input tensor
/// \tparam AccDataType  type of the alpha/beta scalars handed to the kernel
/// \tparam OutDataType  element type of the output tensor
///
/// \param do_verification  when non-zero, each instance's output is compared
///                         against the host ReferenceSoftmax result
/// \param init_method      0 -> GeneratorTensor_1, 1 -> GeneratorTensor_2{-5, 5},
///                         anything else -> GeneratorTensor_3
/// \param do_log           when true (and verifying), dumps in/out_ref/out tensors
/// \param time_kernel      forwarded to StreamConfig for the kernel launch
/// \param in_length        tensor extents; its size selects rank-3 or rank-4 instances
/// \param in_strides       tensor strides; if empty the Tensor(lengths) constructor is used
/// \param reduce_dims      dimensions to reduce over
/// \param alpha, beta      scaling values passed by address to the device op
/// \param norm_type        only NormType::SOFTMAX registers instances here
///
/// \throws std::runtime_error if no instance was registered for the requested
///         norm_type / data types / rank.
template <typename InDataType, typename AccDataType, typename OutDataType>
void profile_normalization_impl(int do_verification,
                                int init_method,
                                bool do_log,
                                bool time_kernel,
                                std::vector<index_t> in_length,
                                std::vector<index_t> in_strides,
                                std::vector<index_t> reduce_dims,
                                AccDataType alpha,
                                AccDataType beta,
                                NormType norm_type)
{
    Tensor<InDataType> in = in_strides.empty() ? Tensor<InDataType>(in_length)
                                               : Tensor<InDataType>(in_length, in_strides);
    Tensor<OutDataType> out(in.mDesc);

    switch(init_method)
    {
    // case 0: break;
    case 0:
        in.GenerateTensorValue(GeneratorTensor_1<InDataType>{});
        out.GenerateTensorValue(GeneratorTensor_1<OutDataType>{});
        break;
    case 1:
        in.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5});
        out.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5});
        break;
    default:
        in.GenerateTensorValue(GeneratorTensor_3<InDataType>{0.0, 1.0});
        out.GenerateTensorValue(GeneratorTensor_3<OutDataType>{-0.5, 0.5});
    }

    // The host reference starts from the same initial output contents as the
    // device buffer (relevant when beta != 0).
    Tensor<OutDataType> out_ref(out);

    DeviceMem in_dev(sizeof(InDataType) * in.mDesc.GetElementSpace());
    DeviceMem out_dev(sizeof(OutDataType) * out.mDesc.GetElementSpace());
    in_dev.ToDevice(in.mData.data());
    out_dev.ToDevice(out.mData.data());

    // Lengths/strides as index_t vectors, the form MakeArgumentPointer is given below.
    std::vector<index_t> i_in_lengths(in.mDesc.GetLengths().begin(), in.mDesc.GetLengths().end());
    std::vector<index_t> i_in_strides(in.mDesc.GetStrides().begin(), in.mDesc.GetStrides().end());

    // add device normalization instances
    std::vector<tensor_operation::device::DeviceNormalizationPtr> instances;

    if(norm_type == NormType::SOFTMAX)
    {
        if constexpr(is_same<InDataType, half_t>::value && is_same<OutDataType, half_t>::value &&
                     is_same<AccDataType, float>::value)
        {
            if(in_length.size() == 3)
                tensor_operation::device::instance::add_device_softmax_f16_f16_rank3_instances(
                    instances);

            if(in_length.size() == 4)
                tensor_operation::device::instance::add_device_softmax_f16_f16_rank4_instances(
                    instances);
        }
        else if constexpr(is_same<InDataType, float>::value && is_same<OutDataType, float>::value &&
                          is_same<AccDataType, float>::value)
        {
            if(in_length.size() == 3)
                tensor_operation::device::instance::add_device_softmax_f32_f32_rank3_instances(
                    instances);

            if(in_length.size() == 4)
                tensor_operation::device::instance::add_device_softmax_f32_f32_rank4_instances(
                    instances);
        }
    }

    if(instances.size() <= 0)
    {
        throw std::runtime_error("wrong! no device normalization instance found");
    }

    // NOTE: if every instance ends up skipped below, the summary line prints
    // these initial values (max float / 0 / empty name).
    std::string best_instance_name;
    float best_avg_time   = std::numeric_limits<float>::max();
    float best_gb_per_sec = 0;

    for(auto& inst_ptr : instances)
    {
        // Is this user's responsibility to check if problem mismatches kernel instance (ie. rank 3
        // problem to rank 4 kernel) other than invoking IsSupportedArgument()?
        if(!(inst_ptr->GetRank() == static_cast<index_t>(i_in_lengths.size()) &&
             inst_ptr->GetNumReduceDim() == static_cast<index_t>(reduce_dims.size())))
        {
            continue;
        }

        auto argument_ptr = inst_ptr->MakeArgumentPointer(i_in_lengths,
                                                          i_in_strides,
                                                          reduce_dims,
                                                          &alpha,
                                                          &beta,
                                                          in_dev.GetDeviceBuffer(),
                                                          out_dev.GetDeviceBuffer());

        if(!inst_ptr->IsSupportedArgument(argument_ptr.get()))
        {
            std::cout << inst_ptr->GetTypeString() << " skipped due to unsupported argument: ";
            LogRange(std::cout << "input lengths = [", in_length, ", ")
                << "], "
                << "scaler = [" << alpha << ", " << beta << "]." << std::endl;

            // Skip only this instance. Returning here would abandon all
            // remaining instances and the best-perf summary, contradicting the
            // "skipped" message above.
            continue;
        }

        auto invoker_ptr = inst_ptr->MakeInvokerPointer();

        float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});

        // Bytes moved: one read of the input, one write of the output, plus one
        // extra pass over the output when beta is non-zero.
        std::size_t num_bytes =
            in.mDesc.GetElementSize() * sizeof(InDataType) +
            (beta == 0.0f ? 1 : 2) * out.mDesc.GetElementSize() * sizeof(OutDataType);

        // avg_time is printed in ms, so bytes / 1e6 / ms == GB/s.
        float gb_per_sec = num_bytes / 1.E6 / avg_time;

        std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << gb_per_sec << " GB/s, "
                  << inst_ptr->GetTypeString() << std::endl;

        if(avg_time < best_avg_time)
        {
            best_instance_name = inst_ptr->GetTypeString();
            best_avg_time      = avg_time;
            best_gb_per_sec    = gb_per_sec;
        }

        if(do_verification)
        {
            // TODO: factory method to dynamically switch between different reference normalizations
            using ReferenceFactory =
                tensor_operation::host::ReferenceSoftmax<InDataType, OutDataType, AccDataType>;

            ReferenceFactory{}.MakeInvoker().Run({in, out_ref, alpha, beta, reduce_dims});

            out_dev.FromDevice(out.mData.data());

            bool pass;
            if(std::is_same<InDataType, int8_t>::value)
            {
                // int8 path: explicit message and tolerances (0, 1) passed to check_err.
                pass = ck::utils::check_err(
                    out.mData, out_ref.mData, "Error: Incorrect results!", 0, 1);
                if(do_log)
                {
                    LogRangeAsType<int>(std::cout << "in  : ", in.mData, ",") << std::endl;
                    LogRangeAsType<int>(std::cout << "out_ref  : ", out_ref.mData, ",")
                        << std::endl;
                    LogRangeAsType<int>(std::cout << "out  : ", out.mData, ",") << std::endl;
                }
            }
            else
            {
                pass = ck::utils::check_err(out.mData, out_ref.mData);
                if(do_log)
                {
                    LogRangeAsType<float>(std::cout << "in  : ", in.mData, ",") << std::endl;
                    LogRangeAsType<float>(std::cout << "out_ref  : ", out_ref.mData, ",")
                        << std::endl;
                    LogRangeAsType<float>(std::cout << "out  : ", out.mData, ",") << std::endl;
                }
            }

            if(!pass)
            {
                std::cout << inst_ptr->GetTypeString() << " failed verification: ";
                LogRange(std::cout << "input lengths = [", in_length, ", ")
                    << "], "
                    << "scaler = [" << alpha << ", " << beta << "]." << std::endl;
            }
        }
    }

    std::cout << "Best Perf for datatype = " << type_to_string<InDataType>() << "_"
              << type_to_string<OutDataType>() << ", ";
    LogRange(std::cout << "length = ", i_in_lengths, ",") << ", ";
    LogRange(std::cout << "stride = ", i_in_strides, ",") << ", ";
    LogRange(std::cout << "reduce dims ", reduce_dims, ",") << ", ";
    std::cout << "alpha = " << alpha << ", "
              << "beta = " << beta << ", " << best_avg_time << " ms, " << best_gb_per_sec
              << " GB/s, " << best_instance_name << std::endl;
}
}
// namespace profiler
}
// namespace ck
profiler/include/profile_reduce_impl.hpp
View file @
c6891e12
...
...
@@ -16,7 +16,7 @@
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_reduce_
instance
{
namespace
instance
{
template
<
int
Rank
,
int
NumReduceDim
,
int
ReduceOpId
,
bool
PropagateNan
,
bool
UseIndex
>
struct
ReduceDescription
...
...
@@ -91,7 +91,7 @@ bool description_match(const DescriptionType& description,
return
(
result
);
};
}
// namespace
device_reduce_
instance
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
...
...
@@ -142,7 +142,7 @@ bool profile_reduce_impl_impl(bool do_verification,
float
beta
)
{
using
namespace
ck
::
tensor_operation
::
device
;
using
namespace
ck
::
tensor_operation
::
device
::
device_reduce_
instance
;
using
namespace
ck
::
tensor_operation
::
device
::
instance
;
using
ck
::
host_common
::
dumpBufferToFile
;
constexpr
bool
op_support_indices
=
...
...
@@ -464,7 +464,7 @@ bool profile_reduce_impl(bool do_verification,
bool
pass
=
true
;
using
tuple_of_description_instances
=
tensor_operation
::
device
::
device_reduce_
instance
::
reduce_description_instances
;
tensor_operation
::
device
::
instance
::
reduce_description_instances
;
const
auto
tuple_object
=
tuple_of_description_instances
{};
...
...
profiler/src/profile_batched_gemm.cpp
View file @
c6891e12
...
...
@@ -86,6 +86,14 @@ int profile_batched_gemm(int argc, char* argv[])
const
int
DefaultStrideB
=
ck
::
is_same_v
<
BLayout
,
Row
>
?
N
:
K
;
const
int
DefaultStrideC
=
ck
::
is_same_v
<
CLayout
,
Row
>
?
N
:
M
;
const
int
StrideA_
=
(
StrideA
<
0
)
?
DefaultStrideA
:
StrideA
;
const
int
StrideB_
=
(
StrideB
<
0
)
?
DefaultStrideB
:
StrideB
;
const
int
StrideC_
=
(
StrideC
<
0
)
?
DefaultStrideC
:
StrideC
;
const
int
BatchStrideA
=
(
ck
::
is_same_v
<
ALayout
,
Row
>
?
M
:
K
)
*
StrideA_
;
const
int
BatchStrideB
=
(
ck
::
is_same_v
<
BLayout
,
Row
>
?
K
:
N
)
*
StrideB_
;
const
int
BatchStrideC
=
(
ck
::
is_same_v
<
CLayout
,
Row
>
?
M
:
N
)
*
StrideC_
;
bool
pass
=
ck
::
profiler
::
profile_batched_gemm_impl
<
ADataType
,
BDataType
,
CDataType
,
ALayout
,
BLayout
,
CLayout
>
(
do_verification
,
...
...
@@ -95,9 +103,12 @@ int profile_batched_gemm(int argc, char* argv[])
M
,
N
,
K
,
(
StrideA
<
0
)
?
DefaultStrideA
:
StrideA
,
(
StrideB
<
0
)
?
DefaultStrideB
:
StrideB
,
(
StrideC
<
0
)
?
DefaultStrideC
:
StrideC
,
BatchStrideA
,
BatchStrideB
,
BatchStrideC
,
StrideA_
,
StrideB_
,
StrideC_
,
BatchCount
);
return
pass
?
0
:
1
;
...
...
profiler/src/profile_gemm_add_add_fastgelu.cpp
View file @
c6891e12
...
...
@@ -75,9 +75,7 @@ int profile_gemm_add_add_fastgelu(int argc, char* argv[])
auto
e_type
,
auto
a_layout
,
auto
b_layout
,
auto
d0_layout
,
auto
d1_layout
,
auto
e_layout
)
{
auto
de_layout
)
{
using
ADataType
=
decltype
(
a_type
);
using
BDataType
=
decltype
(
b_type
);
using
AccDataType
=
decltype
(
acc_type
);
...
...
@@ -87,15 +85,13 @@ int profile_gemm_add_add_fastgelu(int argc, char* argv[])
using
ALayout
=
decltype
(
a_layout
);
using
BLayout
=
decltype
(
b_layout
);
using
D0Layout
=
decltype
(
d0_layout
);
using
D1Layout
=
decltype
(
d1_layout
);
using
ELayout
=
decltype
(
e_layout
);
using
DELayout
=
decltype
(
de_layout
);
const
int
DefaultStrideA
=
ck
::
is_same_v
<
ALayout
,
Row
>
?
K
:
M
;
const
int
DefaultStrideB
=
ck
::
is_same_v
<
BLayout
,
Row
>
?
N
:
K
;
const
int
DefaultStrideD0
=
ck
::
is_same_v
<
D
0
Layout
,
Row
>
?
N
:
M
;
const
int
DefaultStrideD1
=
ck
::
is_same_v
<
D
1
Layout
,
Row
>
?
N
:
M
;
const
int
DefaultStrideE
=
ck
::
is_same_v
<
ELayout
,
Row
>
?
N
:
M
;
const
int
DefaultStrideD0
=
ck
::
is_same_v
<
D
E
Layout
,
Row
>
?
N
:
M
;
const
int
DefaultStrideD1
=
ck
::
is_same_v
<
D
E
Layout
,
Row
>
?
N
:
M
;
const
int
DefaultStrideE
=
ck
::
is_same_v
<
D
ELayout
,
Row
>
?
N
:
M
;
bool
pass
=
ck
::
profiler
::
profile_gemm_add_add_fastgelu_impl
<
ADataType
,
BDataType
,
...
...
@@ -105,9 +101,7 @@ int profile_gemm_add_add_fastgelu(int argc, char* argv[])
EDataType
,
ALayout
,
BLayout
,
D0Layout
,
D1Layout
,
ELayout
>
(
DELayout
>
(
do_verification
,
init_method
,
do_log
,
...
...
@@ -126,22 +120,22 @@ int profile_gemm_add_add_fastgelu(int argc, char* argv[])
if
(
data_type
==
MatrixDataType
::
F16_F16_F16_F16_F16
&&
layout
==
MatrixLayout
::
MK_KN_MN_MN_MN
)
{
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
F16
{},
F16
{},
Row
{},
Row
{},
Row
{},
Row
{},
Row
{});
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
F16
{},
F16
{},
Row
{},
Row
{},
Row
{});
}
else
if
(
data_type
==
MatrixDataType
::
F16_F16_F16_F16_F16
&&
layout
==
MatrixLayout
::
MK_NK_MN_MN_MN
)
{
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
F16
{},
F16
{},
Row
{},
Col
{},
Row
{},
Row
{},
Row
{});
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
F16
{},
F16
{},
Row
{},
Col
{},
Row
{});
}
else
if
(
data_type
==
MatrixDataType
::
F16_F16_F16_F16_F16
&&
layout
==
MatrixLayout
::
KM_KN_MN_MN_MN
)
{
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
F16
{},
F16
{},
Col
{},
Row
{},
Row
{},
Row
{},
Row
{});
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
F16
{},
F16
{},
Col
{},
Row
{},
Row
{});
}
else
if
(
data_type
==
MatrixDataType
::
F16_F16_F16_F16_F16
&&
layout
==
MatrixLayout
::
KM_NK_MN_MN_MN
)
{
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
F16
{},
F16
{},
Col
{},
Col
{},
Row
{},
Row
{},
Row
{});
return
profile
(
F16
{},
F16
{},
F32
{},
F16
{},
F16
{},
F16
{},
Col
{},
Col
{},
Row
{});
}
else
{
...
...
profiler/src/profile_normalization.cpp
0 → 100644
View file @
c6891e12
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <vector>
#include <unordered_map>
#include "profiler/include/profile_normalization_impl.hpp"
using
ck
::
index_t
;
using
ck
::
profiler
::
NormDataType
;
using
ck
::
profiler
::
NormType
;
/// Minimal command-line helper for the normalization profiler.
///
/// `norm_dict` maps the operation name given on the command line to NormType.
/// `long_opts` collects the integer values that follow each `--<key>` flag;
/// it is pre-populated with the recognised keys and empty value lists.
struct ArgParser
{
    std::unordered_map<std::string, NormType> norm_dict = {{"layernorm", NormType::LAYERNORM},
                                                           {"batchnorm", NormType::BATCHNORM},
                                                           {"softmax", NormType::SOFTMAX}};

    std::unordered_map<std::string, std::vector<int>> long_opts = {
        {"length", {}}, {"stride", {}}, {"reduce", {}}, {"alpha", {}}, {"beta", {}}};

    /// If argv[i] is exactly "--<key>", appends every following argument (up to
    /// the next one whose first character is '-', or the end of argv) to
    /// long_opts[key], converting each with std::stoi.
    /// \return true when argv[i] matched the flag, false otherwise.
    bool parse_opt(int argc, char* argv[], const std::string& key, int i)
    {
        const std::string flag = "--" + key;
        if(flag != argv[i])
        {
            return false;
        }

        // Values run from the argument after the flag until the next
        // dash-prefixed argument.
        for(int idx = i + 1; idx < argc && argv[idx][0] != '-'; ++idx)
        {
            long_opts[key].push_back(std::stoi(argv[idx]));
        }
        return true;
    }

    /// Scans argv once per known key and stops at the first occurrence of
    /// that key's flag.
    void operator()(int argc, char* argv[])
    {
        for(auto& entry : long_opts)
        {
            int i = 1;
            while(i < argc && !parse_opt(argc, argv, entry.first, i))
            {
                ++i;
            }
        }
    }
};
/// Prints the usage text of the normalization profiler to std::cout:
/// six positional arguments followed by the optional `--<name> values...` flags.
void print_help()
{
    std::cout << "arg1: tensor operation (layernorm/batchnorm/softmax)\n"
              << "arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n"
              << "arg3: verification (0: no; 1: yes)\n"
              << "arg4: initialization (0: no init; 1: integer value; 2: decimal value)\n"
              << "arg5: print tensor value (0: no; 1: yes)\n"
              // fixed typo: the original text read "0=n0"
              << "arg6: time kernel (0=no, 1=yes)\n"
              << "--length: tensor extents (e.g, --length 8 4 256)\n"
              << "--stride: tensor strides (e.g, --stride 1024 256 1)\n"
              << "--reduce: to-reduce dimensions (e.g, --reduce 2)\n"
              << "--alpha: alpha scaling value\n"
              << "--beta: beta scaling value\n"
              << std::endl;
}
/// Command-line entry point of the normalization profiler.
///
/// Expects six positional arguments (see print_help) followed by optional
/// `--length/--stride/--reduce/--alpha/--beta` flags, then dispatches to
/// profile_normalization_impl for the selected data type.
///
/// \return 0 after printing help or after a completed run.
/// \throws std::runtime_error("not implemented yet") for data types other
///         than F16_F16 and F32_F32.
int profile_normalization(int argc, char* argv[])
{
    // argv[1]..argv[6] are all read unconditionally below, so at least seven
    // entries (program name + six positional arguments) are required. The
    // previous `argc <= 2` guard let 3..6 arguments through and then read
    // past the supplied ones.
    if(argc < 7)
    {
        print_help();
        return 0;
    }

    ArgParser arg_parser;

    // short unnamed options
    const NormType norm_type     = arg_parser.norm_dict[argv[1]];
    const NormDataType data_type = static_cast<NormDataType>(std::stoi(argv[2]));
    const bool do_verification   = std::stoi(argv[3]);
    const int init_method        = std::stoi(argv[4]);
    const bool do_log            = std::stoi(argv[5]);
    const bool time_kernel       = std::stoi(argv[6]);

    // parse the long options
    arg_parser(argc, argv);
    const std::vector<index_t> length = arg_parser.long_opts["length"];
    const std::vector<index_t> stride = arg_parser.long_opts["stride"];
    const std::vector<index_t> reduce = arg_parser.long_opts["reduce"];

    // alpha/beta arrive as integers from ArgParser (values are parsed with
    // std::stoi) and default to 1 and 0 when the flag is absent.
    const index_t alpha =
        arg_parser.long_opts["alpha"].empty() ? 1 : arg_parser.long_opts["alpha"][0];
    const index_t beta = arg_parser.long_opts["beta"].empty() ? 0 : arg_parser.long_opts["beta"][0];

    if(data_type == NormDataType::F16_F16)
    {
        ck::profiler::profile_normalization_impl<ck::half_t, float, ck::half_t>(do_verification,
                                                                                init_method,
                                                                                do_log,
                                                                                time_kernel,
                                                                                length,
                                                                                stride,
                                                                                reduce,
                                                                                float(alpha),
                                                                                float(beta),
                                                                                norm_type);
    }
    else if(data_type == NormDataType::F32_F32)
    {
        ck::profiler::profile_normalization_impl<float, float, float>(do_verification,
                                                                      init_method,
                                                                      do_log,
                                                                      time_kernel,
                                                                      length,
                                                                      stride,
                                                                      reduce,
                                                                      float(alpha),
                                                                      float(beta),
                                                                      norm_type);
    }
    else
    {
        throw std::runtime_error("not implemented yet");
    }

    return 0;
}
// hijack main() for quick debugging
// int main(int argc, char* argv[])
// {
// profile_normalization(argc, argv);
// return 0;
// }
profiler/src/profiler.cpp
View file @
c6891e12
...
...
@@ -20,6 +20,7 @@ int profile_conv_fwd_bias_relu_add(int, char*[]);
int
profile_convnd_fwd
(
int
argc
,
char
*
argv
[]);
int
profile_convnd_bwd_data
(
int
,
char
*
[],
int
);
int
profile_conv_bwd_weight
(
int
,
char
*
[]);
int
profile_normalization
(
int
,
char
*
[]);
int
profile_reduce
(
int
,
char
*
[]);
static
void
print_helper_message
()
...
...
@@ -130,6 +131,11 @@ int main(int argc, char* argv[])
{
return
profile_gemm_add_add_fastgelu
(
argc
,
argv
);
}
else
if
(
strcmp
(
argv
[
1
],
"batchnorm"
)
==
0
||
strcmp
(
argv
[
1
],
"layernorm"
)
==
0
||
strcmp
(
argv
[
1
],
"softmax"
)
==
0
)
{
return
profile_normalization
(
argc
,
argv
);
}
else
{
print_helper_message
();
...
...
script/docker-rocm4.1.sh
deleted
100755 → 0
View file @
f591ad27
WORKSPACE
=
$1
echo
"workspace: "
$WORKSPACE
docker run
\
-it
\
--rm
\
--privileged
\
--group-add
sudo
\
-w
/root/workspace
\
-v
$WORKSPACE
:/root/workspace
\
rocm/tensorflow:rocm4.1-tf1.15-dev
\
/bin/bash
#--network host \
script/docker-rocm4.3.1.sh
deleted
100755 → 0
View file @
f591ad27
WORKSPACE
=
$1
echo
"workspace: "
$WORKSPACE
docker run
\
-it
\
--rm
\
--privileged
\
--group-add
sudo
\
-w
/root/workspace
\
-v
$WORKSPACE
:/root/workspace
\
rocm/tensorflow:rocm4.3.1-tf2.6-dev
\
/bin/bash
#--network host \
test/batched_gemm/batched_gemm_fp16.cpp
View file @
c6891e12
...
...
@@ -25,19 +25,19 @@ int main()
pass
=
pass
&&
ck
::
profiler
::
profile_batched_gemm_impl
<
ADataType
,
BDataType
,
CDataType
,
Row
,
Row
,
Row
>
(
true
,
1
,
false
,
1
,
M
,
N
,
K
,
K
,
N
,
N
,
BatchCount
);
true
,
1
,
false
,
1
,
M
,
N
,
K
,
K
,
N
,
N
,
M
*
K
,
K
*
N
,
M
*
N
,
BatchCount
);
pass
=
pass
&&
ck
::
profiler
::
profile_batched_gemm_impl
<
ADataType
,
BDataType
,
CDataType
,
Row
,
Col
,
Row
>
(
true
,
1
,
false
,
1
,
M
,
N
,
K
,
K
,
K
,
N
,
BatchCount
);
true
,
1
,
false
,
1
,
M
,
N
,
K
,
K
,
K
,
N
,
M
*
K
,
K
*
N
,
M
*
N
,
BatchCount
);
pass
=
pass
&&
ck
::
profiler
::
profile_batched_gemm_impl
<
ADataType
,
BDataType
,
CDataType
,
Col
,
Row
,
Row
>
(
true
,
1
,
false
,
1
,
M
,
N
,
K
,
M
,
N
,
N
,
BatchCount
);
true
,
1
,
false
,
1
,
M
,
N
,
K
,
M
,
N
,
N
,
M
*
K
,
K
*
N
,
M
*
N
,
BatchCount
);
pass
=
pass
&&
ck
::
profiler
::
profile_batched_gemm_impl
<
ADataType
,
BDataType
,
CDataType
,
Col
,
Col
,
Row
>
(
true
,
1
,
false
,
1
,
M
,
N
,
K
,
M
,
K
,
N
,
BatchCount
);
true
,
1
,
false
,
1
,
M
,
N
,
K
,
M
,
K
,
N
,
M
*
K
,
K
*
N
,
M
*
N
,
BatchCount
);
std
::
cout
<<
"test BatchedGEMM fp16: "
<<
(
pass
?
"Pass"
:
"Fail"
)
<<
std
::
endl
;
return
pass
?
0
:
1
;
...
...
test/conv2d_bwd_data/conv2d_bwd_data.cpp
View file @
c6891e12
...
...
@@ -20,7 +20,7 @@ using INT8 = int8_t;
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_conv2d_bwd_data_
instance
{
namespace
instance
{
using
DeviceConvBwdDataNoOpPtr
=
DeviceConvBwdDataPtr
<
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
...
...
@@ -36,7 +36,7 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(
void
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances
(
std
::
vector
<
DeviceConvBwdDataNoOpPtr
>&
);
}
// namespace
device_conv2d_bwd_data_
instance
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
...
...
@@ -220,28 +220,28 @@ int main(int argc, char* argv[])
ck
::
is_same_v
<
ck
::
remove_cv_t
<
WeiDataType
>
,
float
>
&&
ck
::
is_same_v
<
ck
::
remove_cv_t
<
OutDataType
>
,
float
>
)
{
ck
::
tensor_operation
::
device
::
device_conv2d_bwd_data_
instance
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances
(
conv_ptrs
);
}
else
if
constexpr
(
ck
::
is_same_v
<
ck
::
remove_cv_t
<
InDataType
>
,
ck
::
half_t
>
&&
ck
::
is_same_v
<
ck
::
remove_cv_t
<
WeiDataType
>
,
ck
::
half_t
>
&&
ck
::
is_same_v
<
ck
::
remove_cv_t
<
OutDataType
>
,
ck
::
half_t
>
)
{
ck
::
tensor_operation
::
device
::
device_conv2d_bwd_data_
instance
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances
(
conv_ptrs
);
}
else
if
constexpr
(
ck
::
is_same_v
<
ck
::
remove_cv_t
<
InDataType
>
,
ck
::
bhalf_t
>
&&
ck
::
is_same_v
<
ck
::
remove_cv_t
<
WeiDataType
>
,
ck
::
bhalf_t
>
&&
ck
::
is_same_v
<
ck
::
remove_cv_t
<
OutDataType
>
,
ck
::
bhalf_t
>
)
{
ck
::
tensor_operation
::
device
::
device_conv2d_bwd_data_
instance
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances
(
conv_ptrs
);
}
else
if
constexpr
(
ck
::
is_same_v
<
ck
::
remove_cv_t
<
InDataType
>
,
int8_t
>
&&
ck
::
is_same_v
<
ck
::
remove_cv_t
<
WeiDataType
>
,
int8_t
>
&&
ck
::
is_same_v
<
ck
::
remove_cv_t
<
OutDataType
>
,
int8_t
>
)
{
ck
::
tensor_operation
::
device
::
device_conv2d_bwd_data_
instance
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances
(
conv_ptrs
);
}
...
...
test/convnd_fwd/conv_util.hpp
View file @
c6891e12
...
...
@@ -19,14 +19,14 @@ namespace device {
using
DeviceConvFwdNoOpPtr
=
DeviceConvFwdPtr
<
element_wise
::
PassThrough
,
element_wise
::
PassThrough
,
element_wise
::
PassThrough
>
;
namespace
device_conv2d_fwd_
instance
{
namespace
instance
{
void
add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances
(
std
::
vector
<
DeviceConvFwdNoOpPtr
>&
);
void
add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances
(
std
::
vector
<
DeviceConvFwdNoOpPtr
>&
);
void
add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances
(
std
::
vector
<
DeviceConvFwdNoOpPtr
>&
);
void
add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances
(
std
::
vector
<
DeviceConvFwdNoOpPtr
>&
);
}
// namespace
device_conv2d_fwd_
instance
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
...
...
@@ -118,7 +118,7 @@ struct ConvolutionNDFwdInstances<float, float, float>
std
::
vector
<
DeviceConvFwdNoOpPtr
>
conv_ptrs
;
if
(
num_dim_spatial
==
2
)
{
ck
::
tensor_operation
::
device
::
device_conv2d_fwd_
instance
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances
(
conv_ptrs
);
}
return
conv_ptrs
;
...
...
@@ -133,7 +133,7 @@ struct ConvolutionNDFwdInstances<ck::half_t, ck::half_t, ck::half_t>
std
::
vector
<
DeviceConvFwdNoOpPtr
>
conv_ptrs
;
if
(
num_dim_spatial
==
2
)
{
ck
::
tensor_operation
::
device
::
device_conv2d_fwd_
instance
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances
(
conv_ptrs
);
}
return
conv_ptrs
;
...
...
@@ -148,7 +148,7 @@ struct ConvolutionNDFwdInstances<ck::bhalf_t, ck::bhalf_t, ck::bhalf_t>
std
::
vector
<
DeviceConvFwdNoOpPtr
>
conv_ptrs
;
if
(
num_dim_spatial
==
2
)
{
ck
::
tensor_operation
::
device
::
device_conv2d_fwd_
instance
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances
(
conv_ptrs
);
}
return
conv_ptrs
;
...
...
@@ -163,7 +163,7 @@ struct ConvolutionNDFwdInstances<int8_t, int8_t, int8_t>
std
::
vector
<
DeviceConvFwdNoOpPtr
>
conv_ptrs
;
if
(
num_dim_spatial
==
2
)
{
ck
::
tensor_operation
::
device
::
device_conv2d_fwd_
instance
::
ck
::
tensor_operation
::
device
::
instance
::
add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances
(
conv_ptrs
);
}
return
conv_ptrs
;
...
...
test/gemm/CMakeLists.txt
View file @
c6891e12
# GEMM test executables.
# Each test is built from a single source file and linked against the
# host_tensor and device_gemm_instance libraries.

# GEMM XDL
add_test_executable(test_gemm_xdl_fp32 gemm_xdl_fp32.cpp)
target_link_libraries(test_gemm_xdl_fp32 PRIVATE host_tensor)
target_link_libraries(test_gemm_xdl_fp32 PRIVATE device_gemm_instance)

add_test_executable(test_gemm_fp32 gemm_fp32.cpp)
target_link_libraries(test_gemm_fp32 PRIVATE host_tensor)
target_link_libraries(test_gemm_fp32 PRIVATE device_gemm_instance)

add_test_executable(test_gemm_xdl_fp16 gemm_xdl_fp16.cpp)
target_link_libraries(test_gemm_xdl_fp16 PRIVATE host_tensor)
target_link_libraries(test_gemm_xdl_fp16 PRIVATE device_gemm_instance)

add_test_executable(test_gemm_fp16 gemm_fp16.cpp)
target_link_libraries(test_gemm_fp16 PRIVATE host_tensor)
target_link_libraries(test_gemm_fp16 PRIVATE device_gemm_instance)

add_test_executable(test_gemm_xdl_bf16 gemm_xdl_bf16.cpp)
target_link_libraries(test_gemm_xdl_bf16 PRIVATE host_tensor)
target_link_libraries(test_gemm_xdl_bf16 PRIVATE device_gemm_instance)

add_test_executable(test_gemm_bf16 gemm_bf16.cpp)
target_link_libraries(test_gemm_bf16 PRIVATE host_tensor)
target_link_libraries(test_gemm_bf16 PRIVATE device_gemm_instance)

add_test_executable(test_gemm_xdl_int8 gemm_xdl_int8.cpp)
target_link_libraries(test_gemm_xdl_int8 PRIVATE host_tensor)
target_link_libraries(test_gemm_xdl_int8 PRIVATE device_gemm_instance)

# GEMM DL
add_test_executable(test_gemm_dl_fp32 gemm_dl_fp32.cpp)
target_link_libraries(test_gemm_dl_fp32 PRIVATE host_tensor)
target_link_libraries(test_gemm_dl_fp32 PRIVATE device_gemm_instance)

add_test_executable(test_gemm_dl_fp16 gemm_dl_fp16.cpp)
target_link_libraries(test_gemm_dl_fp16 PRIVATE host_tensor)
target_link_libraries(test_gemm_dl_fp16 PRIVATE device_gemm_instance)

add_test_executable(test_gemm_dl_int8 gemm_dl_int8.cpp)
target_link_libraries(test_gemm_dl_int8 PRIVATE host_tensor)
target_link_libraries(test_gemm_dl_int8 PRIVATE device_gemm_instance)

add_test_executable(test_gemm_int8 gemm_int8.cpp)
target_link_libraries(test_gemm_int8 PRIVATE host_tensor)
target_link_libraries(test_gemm_int8 PRIVATE device_gemm_instance)
test/gemm/gemm_bf16.cpp
0 → 100644
View file @
c6891e12
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <algorithm>
#include <cstdlib>
#include <iostream>
#include <numeric>
#include <tuple>
#include <vector>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/gemm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/host_tensor/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "test/gemm/gemm_util.hpp"
int
main
()
{
using
ADataType
=
ck
::
bhalf_t
;
using
BDataType
=
ck
::
bhalf_t
;
using
CDataType
=
ck
::
bhalf_t
;
using
AccDataType
=
float
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
auto
test
=
[
&
](
auto
a_layout
,
auto
b_layout
,
auto
c_layout
)
{
bool
pass
=
true
;
using
DeviceOp
=
ck
::
tensor_operation
::
device
::
DeviceGemm
<
decltype
(
a_layout
),
decltype
(
b_layout
),
decltype
(
c_layout
),
ADataType
,
BDataType
,
CDataType
,
PassThrough
,
PassThrough
,
PassThrough
>
;
const
auto
gemmPtrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
DeviceOp
>::
GetInstances
();
for
(
auto
&
gemmPtr
:
gemmPtrs
)
{
pass
&=
ck
::
gemm_util
::
TestGemm
<
std
::
unique_ptr
<
DeviceOp
>
,
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
decltype
(
a_layout
),
decltype
(
b_layout
),
decltype
(
c_layout
),
PassThrough
,
PassThrough
,
PassThrough
>
{}(
gemmPtr
);
}
return
pass
;
};
bool
pass
=
test
(
Row
{},
Row
{},
Row
{})
&&
test
(
Row
{},
Col
{},
Row
{})
&&
test
(
Col
{},
Row
{},
Row
{})
&&
test
(
Col
{},
Col
{},
Row
{});
std
::
cout
<<
"TestGemm ..... "
<<
(
pass
?
"SUCCESS"
:
"FAILURE"
)
<<
std
::
endl
;
return
pass
?
0
:
1
;
}
test/gemm/gemm_dl_fp16.cpp
deleted
100644 → 0
View file @
f591ad27
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <algorithm>
#include <cstdlib>
#include <iostream>
#include <numeric>
#include <tuple>
#include <vector>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_dl.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/host_tensor/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "test/gemm/gemm_util.hpp"
// Shorthand for the PassThrough element-wise operation type.
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

// DeviceGemmPtr instantiated with PassThrough for all three template arguments.
using DeviceGemmNoOpPtr =
    ck::tensor_operation::device::DeviceGemmPtr<PassThrough, PassThrough, PassThrough>;
// Forward declarations of the fp16 DL GEMM instance registration functions.
// Each one takes the instance list by non-const reference; the definitions
// live outside this file.
namespace ck::tensor_operation::device::device_gemm_instance {

void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>& instances);
void add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>& instances);
void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>& instances);
void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>& instances);

} // namespace ck::tensor_operation::device::device_gemm_instance
int
main
()
{
using
ADataType
=
ck
::
half_t
;
using
BDataType
=
ck
::
half_t
;
using
CDataType
=
ck
::
half_t
;
using
AccDataType
=
float
;
using
RowMajor
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
ColumnMajor
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
bool
res
=
true
;
std
::
vector
<
DeviceGemmNoOpPtr
>
gemmPtrs
;
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances
(
gemmPtrs
);
for
(
auto
&
gemmPtr
:
gemmPtrs
)
{
res
&=
ck
::
gemm_util
::
TestGemm
<
DeviceGemmNoOpPtr
,
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
ColumnMajor
,
RowMajor
,
RowMajor
,
PassThrough
,
PassThrough
,
PassThrough
>
{}(
gemmPtr
);
}
gemmPtrs
.
clear
();
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances
(
gemmPtrs
);
for
(
auto
&
gemmPtr
:
gemmPtrs
)
{
res
&=
ck
::
gemm_util
::
TestGemm
<
DeviceGemmNoOpPtr
,
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
ColumnMajor
,
ColumnMajor
,
RowMajor
,
PassThrough
,
PassThrough
,
PassThrough
>
{}(
gemmPtr
);
}
gemmPtrs
.
clear
();
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances
(
gemmPtrs
);
for
(
auto
&
gemmPtr
:
gemmPtrs
)
{
res
&=
ck
::
gemm_util
::
TestGemm
<
DeviceGemmNoOpPtr
,
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
RowMajor
,
RowMajor
,
RowMajor
,
PassThrough
,
PassThrough
,
PassThrough
>
{}(
gemmPtr
);
}
gemmPtrs
.
clear
();
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances
(
gemmPtrs
);
for
(
auto
&
gemmPtr
:
gemmPtrs
)
{
res
&=
ck
::
gemm_util
::
TestGemm
<
DeviceGemmNoOpPtr
,
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
RowMajor
,
ColumnMajor
,
RowMajor
,
PassThrough
,
PassThrough
,
PassThrough
>
{}(
gemmPtr
);
}
std
::
cout
<<
"TestGemm ..... "
<<
(
res
?
"SUCCESS"
:
"FAILURE"
)
<<
std
::
endl
;
return
res
?
0
:
1
;
}
test/gemm/gemm_dl_fp32.cpp
deleted
100644 → 0
View file @
f591ad27
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <algorithm>
#include <cstdlib>
#include <iostream>
#include <numeric>
#include <tuple>
#include <vector>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/host_tensor/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "test/gemm/gemm_util.hpp"
// Shorthand for the PassThrough element-wise operation type.
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

// DeviceGemmPtr instantiated with PassThrough for all three template arguments.
using DeviceGemmNoOpPtr =
    ck::tensor_operation::device::DeviceGemmPtr<PassThrough, PassThrough, PassThrough>;
// Forward declarations of the fp32 DL GEMM instance registration functions.
// Each one takes the instance list by non-const reference; the definitions
// live outside this file.
namespace ck::tensor_operation::device::device_gemm_instance {

void add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>& instances);
void add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>& instances);
void add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>& instances);
void add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>& instances);

} // namespace ck::tensor_operation::device::device_gemm_instance
int
main
()
{
using
ADataType
=
float
;
using
BDataType
=
float
;
using
CDataType
=
float
;
using
AccDataType
=
float
;
using
RowMajor
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
ColumnMajor
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
bool
res
=
true
;
std
::
vector
<
DeviceGemmNoOpPtr
>
gemmPtrs
;
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances
(
gemmPtrs
);
for
(
auto
&
gemmPtr
:
gemmPtrs
)
{
res
&=
ck
::
gemm_util
::
TestGemm
<
DeviceGemmNoOpPtr
,
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
ColumnMajor
,
RowMajor
,
RowMajor
,
PassThrough
,
PassThrough
,
PassThrough
>
{}(
gemmPtr
);
}
gemmPtrs
.
clear
();
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances
(
gemmPtrs
);
for
(
auto
&
gemmPtr
:
gemmPtrs
)
{
res
&=
ck
::
gemm_util
::
TestGemm
<
DeviceGemmNoOpPtr
,
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
ColumnMajor
,
ColumnMajor
,
RowMajor
,
PassThrough
,
PassThrough
,
PassThrough
>
{}(
gemmPtr
);
}
gemmPtrs
.
clear
();
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances
(
gemmPtrs
);
for
(
auto
&
gemmPtr
:
gemmPtrs
)
{
res
&=
ck
::
gemm_util
::
TestGemm
<
DeviceGemmNoOpPtr
,
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
RowMajor
,
RowMajor
,
RowMajor
,
PassThrough
,
PassThrough
,
PassThrough
>
{}(
gemmPtr
);
}
gemmPtrs
.
clear
();
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances
(
gemmPtrs
);
for
(
auto
&
gemmPtr
:
gemmPtrs
)
{
res
&=
ck
::
gemm_util
::
TestGemm
<
DeviceGemmNoOpPtr
,
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
RowMajor
,
ColumnMajor
,
RowMajor
,
PassThrough
,
PassThrough
,
PassThrough
>
{}(
gemmPtr
);
}
std
::
cout
<<
"TestGemm ..... "
<<
(
res
?
"SUCCESS"
:
"FAILURE"
)
<<
std
::
endl
;
return
res
?
0
:
1
;
}
test/gemm/gemm_dl_int8.cpp
deleted
100644 → 0
View file @
f591ad27
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <algorithm>
#include <cstdlib>
#include <iostream>
#include <numeric>
#include <tuple>
#include <vector>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_dl.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/host_tensor/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "test/gemm/gemm_util.hpp"
// Shorthand for the PassThrough element-wise operation type.
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

// DeviceGemmPtr instantiated with PassThrough for all three template arguments.
using DeviceGemmNoOpPtr =
    ck::tensor_operation::device::DeviceGemmPtr<PassThrough, PassThrough, PassThrough>;
// Forward declarations of the int8 DL GEMM instance registration functions.
// Each one takes the instance list by non-const reference; the definitions
// live outside this file.
namespace ck::tensor_operation::device::device_gemm_instance {

void add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>& instances);
void add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>& instances);
void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>& instances);
void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>& instances);

} // namespace ck::tensor_operation::device::device_gemm_instance
int
main
()
{
using
ADataType
=
int8_t
;
using
BDataType
=
int8_t
;
using
CDataType
=
int8_t
;
using
AccDataType
=
int
;
using
RowMajor
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
ColumnMajor
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
bool
res
=
true
;
std
::
vector
<
DeviceGemmNoOpPtr
>
gemmPtrs
;
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances
(
gemmPtrs
);
for
(
auto
&
gemmPtr
:
gemmPtrs
)
{
res
&=
ck
::
gemm_util
::
TestGemm
<
DeviceGemmNoOpPtr
,
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
ColumnMajor
,
RowMajor
,
RowMajor
,
PassThrough
,
PassThrough
,
PassThrough
>
{}(
gemmPtr
);
}
gemmPtrs
.
clear
();
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances
(
gemmPtrs
);
for
(
auto
&
gemmPtr
:
gemmPtrs
)
{
res
&=
ck
::
gemm_util
::
TestGemm
<
DeviceGemmNoOpPtr
,
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
ColumnMajor
,
ColumnMajor
,
RowMajor
,
PassThrough
,
PassThrough
,
PassThrough
>
{}(
gemmPtr
);
}
gemmPtrs
.
clear
();
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances
(
gemmPtrs
);
for
(
auto
&
gemmPtr
:
gemmPtrs
)
{
res
&=
ck
::
gemm_util
::
TestGemm
<
DeviceGemmNoOpPtr
,
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
RowMajor
,
RowMajor
,
RowMajor
,
PassThrough
,
PassThrough
,
PassThrough
>
{}(
gemmPtr
);
}
gemmPtrs
.
clear
();
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances
(
gemmPtrs
);
for
(
auto
&
gemmPtr
:
gemmPtrs
)
{
res
&=
ck
::
gemm_util
::
TestGemm
<
DeviceGemmNoOpPtr
,
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
RowMajor
,
ColumnMajor
,
RowMajor
,
PassThrough
,
PassThrough
,
PassThrough
>
{}(
gemmPtr
);
}
std
::
cout
<<
"TestGemm ..... "
<<
(
res
?
"SUCCESS"
:
"FAILURE"
)
<<
std
::
endl
;
return
res
?
0
:
1
;
}
test/gemm/gemm_fp16.cpp
0 → 100644
View file @
c6891e12
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <algorithm>
#include <cstdlib>
#include <iostream>
#include <numeric>
#include <tuple>
#include <vector>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/gemm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/host_tensor/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "test/gemm/gemm_util.hpp"
int
main
()
{
using
ADataType
=
ck
::
half_t
;
using
BDataType
=
ck
::
half_t
;
using
CDataType
=
ck
::
half_t
;
using
AccDataType
=
float
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
auto
test
=
[
&
](
auto
a_layout
,
auto
b_layout
,
auto
c_layout
)
{
bool
pass
=
true
;
using
DeviceOp
=
ck
::
tensor_operation
::
device
::
DeviceGemm
<
decltype
(
a_layout
),
decltype
(
b_layout
),
decltype
(
c_layout
),
ADataType
,
BDataType
,
CDataType
,
PassThrough
,
PassThrough
,
PassThrough
>
;
const
auto
gemmPtrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
DeviceOp
>::
GetInstances
();
for
(
auto
&
gemmPtr
:
gemmPtrs
)
{
pass
&=
ck
::
gemm_util
::
TestGemm
<
std
::
unique_ptr
<
DeviceOp
>
,
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
decltype
(
a_layout
),
decltype
(
b_layout
),
decltype
(
c_layout
),
PassThrough
,
PassThrough
,
PassThrough
>
{}(
gemmPtr
);
}
return
pass
;
};
bool
pass
=
test
(
Row
{},
Row
{},
Row
{})
&&
test
(
Row
{},
Col
{},
Row
{})
&&
test
(
Col
{},
Row
{},
Row
{})
&&
test
(
Col
{},
Col
{},
Row
{});
std
::
cout
<<
"TestGemm ..... "
<<
(
pass
?
"SUCCESS"
:
"FAILURE"
)
<<
std
::
endl
;
return
pass
?
0
:
1
;
}
Prev
1
…
10
11
12
13
14
15
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment