Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
b89a88b5
Commit
b89a88b5
authored
Sep 19, 2022
by
Adam Osewski
Browse files
Merge branch 'develop' into wavelet_model
parents
41d5fca7
43c898f6
Changes
261
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1368 additions
and
101 deletions
+1368
-101
profiler/include/profile_batched_gemm_softmax_gemm_impl.hpp
profiler/include/profile_batched_gemm_softmax_gemm_impl.hpp
+11
-3
profiler/include/profile_normalization_impl.hpp
profiler/include/profile_normalization_impl.hpp
+37
-16
profiler/src/profile_batched_gemm_add_relu_gemm_add.cpp
profiler/src/profile_batched_gemm_add_relu_gemm_add.cpp
+209
-0
profiler/src/profile_batched_gemm_gemm.cpp
profiler/src/profile_batched_gemm_gemm.cpp
+181
-0
profiler/src/profile_normalization.cpp
profiler/src/profile_normalization.cpp
+65
-23
profiler/src/profiler.cpp
profiler/src/profiler.cpp
+12
-0
script/clang-format-overwrite.sh
script/clang-format-overwrite.sh
+2
-2
script/process_perf_data.py
script/process_perf_data.py
+14
-1
script/process_qa_data.sh
script/process_qa_data.sh
+4
-2
script/profile_onnx_gemm.sh
script/profile_onnx_gemm.sh
+31
-0
script/profile_splitK_gemm.sh
script/profile_splitK_gemm.sh
+41
-0
script/run_full_performance_tests.sh
script/run_full_performance_tests.sh
+70
-52
test/CMakeLists.txt
test/CMakeLists.txt
+1
-0
test/batched_gemm_gemm/test_batched_gemm_gemm_fp16.cpp
test/batched_gemm_gemm/test_batched_gemm_gemm_fp16.cpp
+111
-1
test/batched_gemm_gemm/test_batched_gemm_gemm_util.hpp
test/batched_gemm_gemm/test_batched_gemm_gemm_util.hpp
+121
-0
test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_fp16.cpp
...gemm_softmax_gemm/test_batched_gemm_softmax_gemm_fp16.cpp
+122
-0
test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_util.hpp
...gemm_softmax_gemm/test_batched_gemm_softmax_gemm_util.hpp
+121
-0
test/data_type/CMakeLists.txt
test/data_type/CMakeLists.txt
+4
-0
test/data_type/int4.cpp
test/data_type/int4.cpp
+211
-0
test/gemm_split_k/gemm_split_k.cpp
test/gemm_split_k/gemm_split_k.cpp
+0
-1
No files found.
profiler/include/profile_batched_gemm_softmax_gemm_impl.hpp
View file @
b89a88b5
...
...
@@ -142,13 +142,21 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
std
::
cout
<<
"b1_g_n_o: "
<<
b1_g_n_o
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"c_g_m_o: "
<<
c_g_m_o_host_result
.
mDesc
<<
std
::
endl
;
std
::
srand
(
1
);
// work around test flakiness
switch
(
init_method
)
{
case
0
:
break
;
case
1
:
a_g_m_k
.
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
5
,
5
});
b0_g_k_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
B0DataType
>
{
-
5
,
5
});
b1_g_n_o
.
GenerateTensorValue
(
GeneratorTensor_2
<
B1DataType
>
{
-
5
,
5
});
// Still unsure whether this kind of deterministic floating point accurary issue is expected
// or not. May want to try exact same approach as the GPU kernel in the host reference
// GEMM+Softmax+GEMM function to see if the accuracy discrepancy goes away. Until then,
// shrink the input value range as it is less likely to produce errors of around ~1e-3.
// a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
// b0_g_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5});
// b1_g_n_o.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-5, 5});
a_g_m_k
.
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
2
,
2
});
b0_g_k_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
B0DataType
>
{
-
2
,
2
});
b1_g_n_o
.
GenerateTensorValue
(
GeneratorTensor_2
<
B1DataType
>
{
-
2
,
2
});
break
;
case
2
:
a_g_m_k
.
GenerateTensorValue
(
GeneratorTensor_3
<
ADataType
>
{
0.0
,
1.0
});
...
...
profiler/include/profile_normalization_impl.hpp
View file @
b89a88b5
...
...
@@ -6,25 +6,36 @@
#include <iomanip>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_softmax.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
#include "ck/tensor_operation/gpu/device/device_softmax.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/utility/data_type.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_softmax_f16_f16_rank3_instances
(
std
::
vector
<
DeviceNormalizationPtr
>&
);
void
add_device_softmax_f16_f16_rank4_instances
(
std
::
vector
<
DeviceNormalizationPtr
>&
);
namespace
{
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
}
// namespace
void
add_device_softmax_f32_f32_rank3_instances
(
std
::
vector
<
DeviceNormalizationPtr
>&
);
void
add_device_softmax_f32_f32_rank4_instances
(
std
::
vector
<
DeviceNormalizationPtr
>&
);
void
add_device_softmax_f16_f16_rank3_instances
(
std
::
vector
<
DeviceSoftmaxPtr
<
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
3
>>&
);
void
add_device_softmax_f16_f16_rank4_instances
(
std
::
vector
<
DeviceSoftmaxPtr
<
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
4
>>&
);
void
add_device_softmax_f32_f32_rank3_instances
(
std
::
vector
<
DeviceSoftmaxPtr
<
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
3
>>&
);
void
add_device_softmax_f32_f32_rank4_instances
(
std
::
vector
<
DeviceSoftmaxPtr
<
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
4
>>&
);
}
// namespace instance
}
// namespace device
...
...
@@ -57,7 +68,7 @@ template <> std::string type_to_string<int8_t>() { return "int8"; }
template
<
>
std
::
string
type_to_string
<
int32_t
>
()
{
return
"int32"
;
}
// clang-format on
template
<
typename
InDataType
,
typename
AccDataType
,
typename
OutDataType
>
template
<
typename
InDataType
,
typename
AccDataType
,
typename
OutDataType
,
index_t
Rank
>
void
profile_normalization_impl
(
int
do_verification
,
int
init_method
,
bool
do_log
,
...
...
@@ -69,6 +80,11 @@ void profile_normalization_impl(int do_verification,
AccDataType
beta
,
NormType
norm_type
)
{
if
(
Rank
!=
in_length
.
size
())
{
throw
std
::
runtime_error
(
"Input tensor rank is different from template argument Rank!"
);
}
Tensor
<
InDataType
>
in
=
in_strides
.
empty
()
?
Tensor
<
InDataType
>
(
in_length
)
:
Tensor
<
InDataType
>
(
in_length
,
in_strides
);
Tensor
<
OutDataType
>
out
(
in
.
mDesc
);
...
...
@@ -99,30 +115,31 @@ void profile_normalization_impl(int do_verification,
std
::
vector
<
index_t
>
i_in_lengths
(
in
.
mDesc
.
GetLengths
().
begin
(),
in
.
mDesc
.
GetLengths
().
end
());
std
::
vector
<
index_t
>
i_in_strides
(
in
.
mDesc
.
GetStrides
().
begin
(),
in
.
mDesc
.
GetStrides
().
end
());
// add device normalization instances
std
::
vector
<
tensor_operation
::
device
::
DeviceNormalizationPtr
>
instances
;
// add device softmax instances
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
DeviceOpPtr
=
tensor_operation
::
device
::
DeviceSoftmaxPtr
<
InDataType
,
AccDataType
,
OutDataType
,
PassThrough
,
PassThrough
,
Rank
>
;
std
::
vector
<
DeviceOpPtr
>
instances
;
if
(
norm_type
==
NormType
::
SOFTMAX
)
{
if
constexpr
(
is_same
<
InDataType
,
half_t
>::
value
&&
is_same
<
OutDataType
,
half_t
>::
value
&&
is_same
<
AccDataType
,
float
>::
value
)
{
if
(
in_length
.
size
()
==
3
)
if
constexpr
(
Rank
==
3
)
tensor_operation
::
device
::
instance
::
add_device_softmax_f16_f16_rank3_instances
(
instances
);
if
(
in_length
.
size
()
==
4
)
else
if
constexpr
(
Rank
==
4
)
tensor_operation
::
device
::
instance
::
add_device_softmax_f16_f16_rank4_instances
(
instances
);
}
else
if
constexpr
(
is_same
<
InDataType
,
float
>::
value
&&
is_same
<
OutDataType
,
float
>::
value
&&
is_same
<
AccDataType
,
float
>::
value
)
{
if
(
in_length
.
size
()
==
3
)
if
constexpr
(
Rank
==
3
)
tensor_operation
::
device
::
instance
::
add_device_softmax_f32_f32_rank3_instances
(
instances
);
if
(
in_length
.
size
()
==
4
)
else
if
constexpr
(
Rank
==
4
)
tensor_operation
::
device
::
instance
::
add_device_softmax_f32_f32_rank4_instances
(
instances
);
}
...
...
@@ -137,6 +154,8 @@ void profile_normalization_impl(int do_verification,
float
best_avg_time
=
std
::
numeric_limits
<
float
>::
max
();
float
best_gb_per_sec
=
0
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
for
(
auto
&
inst_ptr
:
instances
)
{
// Is this user's responsibility to check if problem mismatches kernel instance (ie. rank 3
...
...
@@ -153,7 +172,9 @@ void profile_normalization_impl(int do_verification,
&
alpha
,
&
beta
,
in_dev
.
GetDeviceBuffer
(),
out_dev
.
GetDeviceBuffer
());
out_dev
.
GetDeviceBuffer
(),
PassThrough
{},
PassThrough
{});
if
(
!
inst_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
...
...
profiler/src/profile_batched_gemm_add_relu_gemm_add.cpp
0 → 100644
View file @
b89a88b5
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "profiler/include/profile_batched_gemm_add_relu_gemm_add_impl.hpp"
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
int
profile_batched_gemm_add_relu_gemm_add
(
int
argc
,
char
*
argv
[])
{
enum
struct
GemmMatrixLayout
{
MK_NK_MN_NO_MO_MO
,
// 0
MK_NK_MN_ON_MO_MO
,
// 1
};
enum
struct
GemmDataType
{
F32_F32_F32_F32_F32_F32
,
// 0
F16_F16_F16_F16_F16_F16
,
// 1
};
GemmDataType
data_type
=
GemmDataType
::
F16_F16_F16_F16_F16_F16
;
GemmMatrixLayout
layout
=
GemmMatrixLayout
::
MK_NK_MN_NO_MO_MO
;
bool
do_verification
=
true
;
int
init_method
=
1
;
bool
do_log
=
0
;
bool
time_kernel
=
false
;
// GEMM shape
ck
::
index_t
M
=
1024
;
ck
::
index_t
N
=
1024
;
ck
::
index_t
K
=
64
;
ck
::
index_t
O
=
128
;
ck
::
index_t
BatchCount
=
4
;
ck
::
index_t
StrideA0
=
-
1
;
ck
::
index_t
StrideB0
=
-
1
;
ck
::
index_t
StrideD0
=
-
1
;
ck
::
index_t
StrideB1
=
-
1
;
ck
::
index_t
StrideD1
=
-
1
;
ck
::
index_t
StrideE1
=
-
1
;
ck
::
index_t
BatchStrideA0
=
-
1
;
ck
::
index_t
BatchStrideB0
=
-
1
;
ck
::
index_t
BatchStrideD0
=
-
1
;
ck
::
index_t
BatchStrideB1
=
-
1
;
ck
::
index_t
BatchStrideD1
=
-
1
;
ck
::
index_t
BatchStrideE1
=
-
1
;
if
(
argc
==
8
)
{
data_type
=
static_cast
<
GemmDataType
>
(
std
::
stoi
(
argv
[
2
]));
layout
=
static_cast
<
GemmMatrixLayout
>
(
std
::
stoi
(
argv
[
3
]));
do_verification
=
std
::
stoi
(
argv
[
4
]);
init_method
=
std
::
stoi
(
argv
[
5
]);
do_log
=
std
::
stoi
(
argv
[
6
]);
time_kernel
=
std
::
stoi
(
argv
[
7
]);
}
else
if
(
argc
==
13
)
{
data_type
=
static_cast
<
GemmDataType
>
(
std
::
stoi
(
argv
[
2
]));
layout
=
static_cast
<
GemmMatrixLayout
>
(
std
::
stoi
(
argv
[
3
]));
do_verification
=
std
::
stoi
(
argv
[
4
]);
init_method
=
std
::
stoi
(
argv
[
5
]);
do_log
=
std
::
stoi
(
argv
[
6
]);
time_kernel
=
std
::
stoi
(
argv
[
7
]);
M
=
std
::
stoi
(
argv
[
8
]);
N
=
std
::
stoi
(
argv
[
9
]);
K
=
std
::
stoi
(
argv
[
10
]);
O
=
std
::
stoi
(
argv
[
11
]);
BatchCount
=
std
::
stoi
(
argv
[
12
]);
}
else
if
(
argc
==
25
)
{
data_type
=
static_cast
<
GemmDataType
>
(
std
::
stoi
(
argv
[
2
]));
layout
=
static_cast
<
GemmMatrixLayout
>
(
std
::
stoi
(
argv
[
3
]));
do_verification
=
std
::
stoi
(
argv
[
4
]);
init_method
=
std
::
stoi
(
argv
[
5
]);
do_log
=
std
::
stoi
(
argv
[
6
]);
time_kernel
=
std
::
stoi
(
argv
[
7
]);
M
=
std
::
stoi
(
argv
[
8
]);
N
=
std
::
stoi
(
argv
[
9
]);
K
=
std
::
stoi
(
argv
[
10
]);
O
=
std
::
stoi
(
argv
[
11
]);
BatchCount
=
std
::
stoi
(
argv
[
12
]);
StrideA0
=
std
::
stoi
(
argv
[
13
]);
StrideB0
=
std
::
stoi
(
argv
[
14
]);
StrideD0
=
std
::
stoi
(
argv
[
15
]);
StrideB1
=
std
::
stoi
(
argv
[
16
]);
StrideD1
=
std
::
stoi
(
argv
[
17
]);
StrideE1
=
std
::
stoi
(
argv
[
18
]);
BatchStrideA0
=
std
::
stoi
(
argv
[
19
]);
BatchStrideB0
=
std
::
stoi
(
argv
[
20
]);
BatchStrideD0
=
std
::
stoi
(
argv
[
21
]);
BatchStrideB1
=
std
::
stoi
(
argv
[
22
]);
BatchStrideD1
=
std
::
stoi
(
argv
[
23
]);
BatchStrideE1
=
std
::
stoi
(
argv
[
24
]);
}
else
{
printf
(
"arg1: tensor operation (batched_gemm_add_relu_gemm_add: "
"Batched_GEMM+Add+Relu+Gemm+Add)
\n
"
);
printf
(
"arg2: data type (1: fp16)
\n
"
);
printf
(
"arg3: matrix layout (0: Relu(A0[m, k] * B0[n, k] + D0[m, n]) * B1[n, o] + D1[m, o] "
"= E1[m, o]; 1: Relu(A0[m, k] * B0[n, k] + D0[m, n]) * B1[o, n] + D1[m, o] = "
"E1[m, o];)
\n
"
);
printf
(
"arg4: verification (0: no; 1: yes)
\n
"
);
printf
(
"arg5: initialization (0: no init; 1: integer value; 2: decimal value)
\n
"
);
printf
(
"arg6: print tensor value (0: no; 1: yes)
\n
"
);
printf
(
"arg7: time kernel (0=no, 1=yes)
\n
"
);
printf
(
"arg8 to 12: M, N, K, O, Batch
\n
"
);
printf
(
"arg13 to 18: StrideA0, StrideB0, StrideD0, StrideB1, StrideD1, StrideE1
\n
"
);
printf
(
"arg19 to 24: BatchStrideA0, BatchStrideB0, BatchStrideD0, BatchStrideB1, "
"BatchStrideD1, BatchStrideE1
\n
"
);
exit
(
1
);
}
if
(
data_type
==
GemmDataType
::
F16_F16_F16_F16_F16_F16
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN_NO_MO_MO
)
{
ck
::
profiler
::
profile_batched_gemm_add_relu_gemm_add_impl
<
Row
,
// A0Layout,
Col
,
// B0Layout,
ck
::
Tuple
<
Row
>
,
// D0sLayout,
Row
,
// B1Layout,
ck
::
Tuple
<
Row
>
,
// D1sLayout,
Row
,
// E1Layout,
F16
,
// A0DataType,
F16
,
// B0DataType,
ck
::
Tuple
<
F16
>
,
// D0DataType,
F16
,
// B1DataType,
ck
::
Tuple
<
F16
>
,
// D1sDataType
F16
>
// E1DataType,
(
do_verification
,
init_method
,
do_log
,
time_kernel
,
M
,
N
,
K
,
O
,
BatchCount
,
StrideA0
,
StrideB0
,
StrideD0
,
StrideB1
,
StrideD1
,
StrideE1
,
BatchStrideA0
,
BatchStrideB0
,
BatchStrideD0
,
BatchStrideB1
,
BatchStrideD1
,
BatchStrideE1
);
}
else
if
(
data_type
==
GemmDataType
::
F16_F16_F16_F16_F16_F16
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN_ON_MO_MO
)
{
ck
::
profiler
::
profile_batched_gemm_add_relu_gemm_add_impl
<
Row
,
// A0Layout,
Col
,
// B0Layout,
ck
::
Tuple
<
Row
>
,
// D0sLayout,
Col
,
// B1Layout,
ck
::
Tuple
<
Row
>
,
// D1sLayout,
Row
,
// E1Layout,
F16
,
// A0DataType,
F16
,
// B0DataType,
ck
::
Tuple
<
F16
>
,
// D0DataType,
F16
,
// B1DataType,
ck
::
Tuple
<
F16
>
,
// D1sDataType
F16
>
// E1DataType,
(
do_verification
,
init_method
,
do_log
,
time_kernel
,
M
,
N
,
K
,
O
,
BatchCount
,
StrideA0
,
StrideB0
,
StrideD0
,
StrideB1
,
StrideD1
,
StrideE1
,
BatchStrideA0
,
BatchStrideB0
,
BatchStrideD0
,
BatchStrideB1
,
BatchStrideD1
,
BatchStrideE1
);
}
else
{
throw
std
::
runtime_error
(
"wrong! this data_type & layout is not implemented"
);
}
return
0
;
}
profiler/src/profile_batched_gemm_gemm.cpp
0 → 100644
View file @
b89a88b5
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "profiler/include/profile_batched_gemm_gemm_impl.hpp"
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
int
profile_batched_gemm_gemm
(
int
argc
,
char
*
argv
[])
{
enum
struct
GemmMatrixLayout
{
MK_NK_NO_MO
,
// 0
MK_NK_ON_MO
,
// 0
};
enum
struct
GemmDataType
{
F32_F32_F32_F32
,
// 0
F16_F16_F16_F16
,
// 1
};
GemmDataType
data_type
=
GemmDataType
::
F16_F16_F16_F16
;
GemmMatrixLayout
layout
=
GemmMatrixLayout
::
MK_NK_NO_MO
;
bool
do_verification
=
true
;
int
init_method
=
1
;
bool
do_log
=
0
;
bool
time_kernel
=
false
;
// GEMM shape
ck
::
index_t
M
=
1024
;
ck
::
index_t
N
=
1024
;
ck
::
index_t
K
=
64
;
ck
::
index_t
O
=
128
;
ck
::
index_t
BatchCount
=
4
;
ck
::
index_t
StrideA0
=
-
1
;
ck
::
index_t
StrideB0
=
-
1
;
ck
::
index_t
StrideB1
=
-
1
;
ck
::
index_t
StrideE1
=
-
1
;
ck
::
index_t
BatchStrideA0
=
-
1
;
ck
::
index_t
BatchStrideB0
=
-
1
;
ck
::
index_t
BatchStrideB1
=
-
1
;
ck
::
index_t
BatchStrideE1
=
-
1
;
if
(
argc
==
8
)
{
data_type
=
static_cast
<
GemmDataType
>
(
std
::
stoi
(
argv
[
2
]));
layout
=
static_cast
<
GemmMatrixLayout
>
(
std
::
stoi
(
argv
[
3
]));
do_verification
=
std
::
stoi
(
argv
[
4
]);
init_method
=
std
::
stoi
(
argv
[
5
]);
do_log
=
std
::
stoi
(
argv
[
6
]);
time_kernel
=
std
::
stoi
(
argv
[
7
]);
}
else
if
(
argc
==
13
)
{
data_type
=
static_cast
<
GemmDataType
>
(
std
::
stoi
(
argv
[
2
]));
layout
=
static_cast
<
GemmMatrixLayout
>
(
std
::
stoi
(
argv
[
3
]));
do_verification
=
std
::
stoi
(
argv
[
4
]);
init_method
=
std
::
stoi
(
argv
[
5
]);
do_log
=
std
::
stoi
(
argv
[
6
]);
time_kernel
=
std
::
stoi
(
argv
[
7
]);
M
=
std
::
stoi
(
argv
[
8
]);
N
=
std
::
stoi
(
argv
[
9
]);
K
=
std
::
stoi
(
argv
[
10
]);
O
=
std
::
stoi
(
argv
[
11
]);
BatchCount
=
std
::
stoi
(
argv
[
12
]);
}
else
if
(
argc
==
21
)
{
data_type
=
static_cast
<
GemmDataType
>
(
std
::
stoi
(
argv
[
2
]));
layout
=
static_cast
<
GemmMatrixLayout
>
(
std
::
stoi
(
argv
[
3
]));
do_verification
=
std
::
stoi
(
argv
[
4
]);
init_method
=
std
::
stoi
(
argv
[
5
]);
do_log
=
std
::
stoi
(
argv
[
6
]);
time_kernel
=
std
::
stoi
(
argv
[
7
]);
M
=
std
::
stoi
(
argv
[
8
]);
N
=
std
::
stoi
(
argv
[
9
]);
K
=
std
::
stoi
(
argv
[
10
]);
O
=
std
::
stoi
(
argv
[
11
]);
BatchCount
=
std
::
stoi
(
argv
[
12
]);
StrideA0
=
std
::
stoi
(
argv
[
13
]);
StrideB0
=
std
::
stoi
(
argv
[
14
]);
StrideB1
=
std
::
stoi
(
argv
[
15
]);
StrideE1
=
std
::
stoi
(
argv
[
16
]);
BatchStrideA0
=
std
::
stoi
(
argv
[
17
]);
BatchStrideB0
=
std
::
stoi
(
argv
[
18
]);
BatchStrideB1
=
std
::
stoi
(
argv
[
19
]);
BatchStrideE1
=
std
::
stoi
(
argv
[
20
]);
}
else
{
printf
(
"arg1: tensor operation (batched_gemm_gemm: Batched_GEMM+Gemm)
\n
"
);
printf
(
"arg2: data type (1: fp16)
\n
"
);
printf
(
"arg3: matrix layout (0: Relu(A0[m, k] * B0[n, k] + D0[m, n]) * B1[n, o] + D1[m, o] "
"= E1[m, o]; 1: Relu(A0[m, k] * B0[n, k] + D0[m, n]) * B1[o, n] + D1[m, o] = E1[m, "
"o];)
\n
"
);
printf
(
"arg4: verification (0: no; 1: yes)
\n
"
);
printf
(
"arg5: initialization (0: no init; 1: integer value; 2: decimal value)
\n
"
);
printf
(
"arg6: print tensor value (0: no; 1: yes)
\n
"
);
printf
(
"arg7: time kernel (0=no, 1=yes)
\n
"
);
printf
(
"arg8 to 12: M, N, K, O, Batch
\n
"
);
printf
(
"arg13 to 16: StrideA0, StrideB0, StrideB1, StrideE1
\n
"
);
printf
(
"arg17 to 20: BatchStrideA0, BatchStrideB0, BatchStrideB1, BatchStrideE1
\n
"
);
exit
(
1
);
}
if
(
data_type
==
GemmDataType
::
F16_F16_F16_F16
&&
layout
==
GemmMatrixLayout
::
MK_NK_NO_MO
)
{
ck
::
profiler
::
profile_batched_gemm_gemm_impl
<
F16
,
// A0DataType,
F16
,
// B0DataType,
F16
,
// B1DataType,
F16
,
// E1DataType,
Row
,
// A0Layout,
Col
,
// B0Layout,
Row
,
// B1Layout,
Row
>
// E1Layout,
(
do_verification
,
init_method
,
do_log
,
time_kernel
,
M
,
N
,
K
,
O
,
BatchCount
,
StrideA0
,
StrideB0
,
StrideB1
,
StrideE1
,
BatchStrideA0
,
BatchStrideB0
,
BatchStrideB1
,
BatchStrideE1
);
}
else
if
(
data_type
==
GemmDataType
::
F16_F16_F16_F16
&&
layout
==
GemmMatrixLayout
::
MK_NK_ON_MO
)
{
ck
::
profiler
::
profile_batched_gemm_gemm_impl
<
F16
,
// A0DataType,
F16
,
// B0DataType,
F16
,
// B1DataType,
F16
,
// E1DataType,
Row
,
// A0Layout,
Col
,
// B0Layout,
Col
,
// B1Layout,
Row
>
// E1Layout,
(
do_verification
,
init_method
,
do_log
,
time_kernel
,
M
,
N
,
K
,
O
,
BatchCount
,
StrideA0
,
StrideB0
,
StrideB1
,
StrideE1
,
BatchStrideA0
,
BatchStrideB0
,
BatchStrideB1
,
BatchStrideE1
);
}
else
{
throw
std
::
runtime_error
(
"wrong! this data_type & layout is not implemented"
);
}
return
0
;
}
profiler/src/profile_normalization.cpp
View file @
b89a88b5
...
...
@@ -50,7 +50,7 @@ struct ArgParser
void
print_help
()
{
std
::
cout
<<
"arg1: tensor operation (
layernorm/
batchnorm/softmax)
\n
"
std
::
cout
<<
"arg1: tensor operation (batchnorm/softmax)
\n
"
<<
"arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)
\n
"
<<
"arg3: verification (0: no; 1: yes)
\n
"
<<
"arg4: initialization (0: no init; 1: integer value; 2: decimal value)
\n
"
...
...
@@ -91,31 +91,73 @@ int profile_normalization(int argc, char* argv[])
arg_parser
.
long_opts
[
"alpha"
].
empty
()
?
1
:
arg_parser
.
long_opts
[
"alpha"
][
0
];
const
index_t
beta
=
arg_parser
.
long_opts
[
"beta"
].
empty
()
?
0
:
arg_parser
.
long_opts
[
"beta"
][
0
];
if
(
data_type
==
NormDataType
::
F16_F16
)
if
(
length
.
size
()
==
3
)
{
ck
::
profiler
::
profile_normalization_impl
<
ck
::
half_t
,
float
,
ck
::
half_t
>
(
do_verification
,
init_method
,
do_log
,
time_kernel
,
length
,
stride
,
reduce
,
float
(
alpha
),
float
(
beta
),
norm_type
);
if
(
data_type
==
NormDataType
::
F16_F16
)
{
ck
::
profiler
::
profile_normalization_impl
<
ck
::
half_t
,
float
,
ck
::
half_t
,
3
>
(
do_verification
,
init_method
,
do_log
,
time_kernel
,
length
,
stride
,
reduce
,
float
(
alpha
),
float
(
beta
),
norm_type
);
}
else
if
(
data_type
==
NormDataType
::
F32_F32
)
{
ck
::
profiler
::
profile_normalization_impl
<
float
,
float
,
float
,
3
>
(
do_verification
,
init_method
,
do_log
,
time_kernel
,
length
,
stride
,
reduce
,
float
(
alpha
),
float
(
beta
),
norm_type
);
}
else
{
throw
std
::
runtime_error
(
"not implemented yet"
);
}
}
else
if
(
data_type
==
NormDataType
::
F32_F32
)
else
if
(
length
.
size
()
==
4
)
{
ck
::
profiler
::
profile_normalization_impl
<
float
,
float
,
float
>
(
do_verification
,
init_method
,
do_log
,
time_kernel
,
length
,
stride
,
reduce
,
float
(
alpha
),
float
(
beta
),
norm_type
);
if
(
data_type
==
NormDataType
::
F16_F16
)
{
ck
::
profiler
::
profile_normalization_impl
<
ck
::
half_t
,
float
,
ck
::
half_t
,
4
>
(
do_verification
,
init_method
,
do_log
,
time_kernel
,
length
,
stride
,
reduce
,
float
(
alpha
),
float
(
beta
),
norm_type
);
}
else
if
(
data_type
==
NormDataType
::
F32_F32
)
{
ck
::
profiler
::
profile_normalization_impl
<
float
,
float
,
float
,
4
>
(
do_verification
,
init_method
,
do_log
,
time_kernel
,
length
,
stride
,
reduce
,
float
(
alpha
),
float
(
beta
),
norm_type
);
}
else
{
throw
std
::
runtime_error
(
"not implemented yet"
);
}
}
else
{
...
...
profiler/src/profiler.cpp
View file @
b89a88b5
...
...
@@ -10,6 +10,8 @@ int profile_gemm_add_add_fastgelu(int, char*[]);
int
profile_gemm_reduce
(
int
,
char
*
[]);
int
profile_gemm_bias_add_reduce
(
int
,
char
*
[]);
int
profile_batched_gemm
(
int
,
char
*
[]);
int
profile_batched_gemm_gemm
(
int
,
char
*
[]);
int
profile_batched_gemm_add_relu_gemm_add
(
int
,
char
*
[]);
int
profile_batched_gemm_reduce
(
int
,
char
*
[]);
int
profile_grouped_gemm
(
int
,
char
*
[]);
int
profile_conv_fwd
(
int
,
char
*
[]);
...
...
@@ -32,6 +34,8 @@ static void print_helper_message()
" gemm_reduce: GEMM+Reduce
\n
"
" gemm_bias_add_reduce: GEMM+Bias+Add+Reduce
\n
"
" batched_gemm: Batched GEMM
\n
"
" batched_gemm_gemm: Batched+GEMM+GEMM
\n
"
" batched_gemm_add_relu_gemm_add: Batched+GEMM+bias+gelu+GEMM+bias
\n
"
" batched_gemm_reduce: Batched GEMM+Reduce
\n
"
" grouped_gemm: Grouped GEMM
\n
"
" conv_fwd: Convolution Forward
\n
"
...
...
@@ -80,6 +84,14 @@ int main(int argc, char* argv[])
{
return
profile_batched_gemm
(
argc
,
argv
);
}
else
if
(
strcmp
(
argv
[
1
],
"batched_gemm_gemm"
)
==
0
)
{
return
profile_batched_gemm_gemm
(
argc
,
argv
);
}
else
if
(
strcmp
(
argv
[
1
],
"batched_gemm_add_relu_gemm_add"
)
==
0
)
{
return
profile_batched_gemm_add_relu_gemm_add
(
argc
,
argv
);
}
else
if
(
strcmp
(
argv
[
1
],
"batched_gemm_reduce"
)
==
0
)
{
return
profile_batched_gemm_reduce
(
argc
,
argv
);
...
...
script/clang-format-overwrite.sh
View file @
b89a88b5
#find . -name deps -prune -o -name build -prune -o -iname '*.h' -o -iname '*.hpp' -o -iname '*.cpp' -o -iname '*.h.in' -o -iname '*.hpp.in' -o -iname '*.cpp.in' -o -iname '*.cl' -o -iname '*.cuh' -o -iname '*.cu' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-10 -i -style=file {}'
git status
--porcelain
|
awk
'$1 != "D" && (match($2, "\\.cpp|hpp")) {print $2}'
| xargs
-n
1
-P
16
-I
{}
-t
sh
-c
'clang-format-10 -i -style=file {}'
#find . -name deps -prune -o -name build -prune -o -iname '*.h' -o -iname '*.hpp' -o -iname '*.cpp' -o -iname '*.h.in' -o -iname '*.hpp.in' -o -iname '*.cpp.in' -o -iname '*.cl' -o -iname '*.cuh' -o -iname '*.cu'
-o -iname '*.inc'
| xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-10 -i -style=file {}'
git status
--porcelain
|
awk
'$1 != "D" && (match($2, "\\.cpp|hpp
|inc
")) {print $2}'
| xargs
-n
1
-P
16
-I
{}
-t
sh
-c
'clang-format-10 -i -style=file {}'
script/process_perf_data.py
View file @
b89a88b5
...
...
@@ -127,11 +127,16 @@ def parse_logfile(logfile):
lst
=
line
.
split
()
res
.
append
(
lst
[
1
])
#parse all other performance tests:
elif
'resnet50'
or
'batched_gemm'
or
'grouped_gemm'
or
'conv_bwd_data'
or
'gemm_bilinear'
or
'reduction'
in
logfile
:
elif
'resnet50'
in
logfile
or
'batched_gemm'
in
logfile
or
'grouped_gemm'
in
logfile
or
'conv_bwd_data'
in
logfile
or
'gemm_bilinear'
in
logfile
or
'reduction'
in
logfile
:
for
line
in
open
(
logfile
):
if
'Best Perf'
in
line
:
lst
=
line
.
split
()
res
.
append
(
lst
[
4
])
elif
'onnx_gemm'
in
logfile
or
'splitK_gemm'
in
logfile
:
for
line
in
open
(
logfile
):
if
'Best Perf'
in
line
:
lst
=
line
.
split
()
res
.
append
(
lst
[
33
])
return
res
...
...
@@ -281,6 +286,14 @@ def main():
for
i
in
range
(
1
,
50
):
testlist
.
append
(
"Layer%i"
%
i
)
table_name
=
"ck_resnet50_N256_tflops"
if
'onnx_gemm'
in
filename
:
for
i
in
range
(
1
,
len
(
results
)
+
1
):
testlist
.
append
(
"Test%i"
%
i
)
table_name
=
"ck_onnx_gemm_tflops"
if
'splitK_gemm'
in
filename
:
for
i
in
range
(
1
,
len
(
results
)
+
1
):
testlist
.
append
(
"Test%i"
%
i
)
table_name
=
"ck_splitK_gemm_tflops"
tflops_base
=
get_baseline
(
table_name
,
conn
)
store_new_test_result
(
table_name
,
results
,
testlist
,
branch_name
,
node_id
,
gpu_arch
,
compute_units
,
rocm_vers
,
hip_vers
,
environment
,
conn
)
...
...
script/process_qa_data.sh
View file @
b89a88b5
...
...
@@ -2,8 +2,8 @@
#
# in order to run this script you'd need the following python packages:
pip3
install
--upgrade
pip
pip3
install
sqlalchemy pymysql pandas sshtunnel
#
pip3 install --upgrade pip
#
pip3 install sqlalchemy pymysql pandas sshtunnel
# you would also need to set up some environment variables in order to
# post your new test results to the database and compare them to the baseline
...
...
@@ -20,3 +20,5 @@ python3 process_perf_data.py perf_conv_fwd_"$gpu_arch".log
python3 process_perf_data.py perf_conv_bwd_data_
"
$gpu_arch
"
.log
python3 process_perf_data.py perf_gemm_bilinear_
"
$gpu_arch
"
.log
python3 process_perf_data.py perf_reduction_
"
$gpu_arch
"
.log
python3 process_perf_data.py perf_splitK_gemm_
"
$gpu_arch
"
.log
python3 process_perf_data.py perf_onnx_gemm_
"
$gpu_arch
"
.log
script/profile_onnx_gemm.sh
0 → 100755
View file @
b89a88b5
#!/bin/bash
## GPU visibility
export
HIP_VISIBLE_DEVICES
=
0
DRIVER
=
"../build/bin/ckProfiler"
echo
$DRIVER
OP
=
$1
DATATYPE
=
$2
LAYOUT
=
$3
VERIFY
=
$4
INIT
=
$5
LOG
=
$6
TIME
=
$7
# GEMM kernel benchmarks used by ONNX
######## op datatype layout verify init log time M___ N___ K___ StrideA StrideB StrideC
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$TIME
384 768 768
-1
-1
-1
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$TIME
384 768 2304
-1
-1
-1
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$TIME
384 768 3072
-1
-1
-1
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$TIME
384 3072 768
-1
-1
-1
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$TIME
384 1024 1024
-1
-1
-1
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$TIME
384 1024 3072
-1
-1
-1
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$TIME
384 1024 4096
-1
-1
-1
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$TIME
384 4096 1024
-1
-1
-1
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$TIME
24576 768 768
-1
-1
-1
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$TIME
24576 768 2304
-1
-1
-1
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$TIME
24576 768 3072
-1
-1
-1
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$TIME
24576 3072 768
-1
-1
-1
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$TIME
24576 1024 1024
-1
-1
-1
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$TIME
24576 1024 3072
-1
-1
-1
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$TIME
24576 1024 4096
-1
-1
-1
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$TIME
24576 4096 1024
-1
-1
-1
script/profile_splitK_gemm.sh
0 → 100755
View file @
b89a88b5
#!/bin/bash
## GPU visibility
export
HIP_VISIBLE_DEVICES
=
0
DRIVER
=
"../build/bin/ckProfiler"
echo
$DRIVER
OP
=
$1
DATATYPE
=
$2
LAYOUT
=
$3
VERIFY
=
$4
INIT
=
$5
LOG
=
$6
TIME
=
$7
KBatch
=
$8
# 120 CU
######## op datatype layout verify init log time M___ N___ K___ StrideA StrideB StrideC KBatch_
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$TIME
960 1024 1024
-1
-1
-1
$KBatch
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$TIME
960 2048 2048
-1
-1
-1
$KBatch
# 104 CU
######## op datatype layout verify init log time M___ N___ K___ StrideA StrideB StrideC KBatch_
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$TIME
832 1024 1024
-1
-1
-1
$KBatch
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$TIME
832 2048 2048
-1
-1
-1
$KBatch
# 110 CU
######## op datatype layout verify init log time M___ N___ K___ StrideA StrideB StrideC KBatch_
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$TIME
1280 1408 1024
-1
-1
-1
$KBatch
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$TIME
1280 2816 2048
-1
-1
-1
$KBatch
# testing different strides
######## op datatype layout verify init log time M___ N___ K___ StrideA StrideB StrideC KBatch_
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$TIME
1024 1024 1024 1024 1024 1024
$KBatch
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$TIME
2048 2048 2048 2048 2048 2048
$KBatch
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$TIME
1024 1024 1024 1056 1056 1056
$KBatch
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$TIME
2048 2048 2048 2080 2080 2080
$KBatch
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$TIME
1024 1024 1024 1088 1088 1088
$KBatch
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$TIME
2048 2048 2048 2112 2112 2112
$KBatch
script/run_full_performance_tests.sh
View file @
b89a88b5
...
...
@@ -40,85 +40,103 @@ function print_log_header(){
#run gemm tests
export
gemm_log
=
"perf_gemm_
${
gpu_arch
}
.log"
print_log_header
$gemm_log
$env_type
$branch
$host_name
./profile_gemm.sh gemm 0 0
$verify
1 0 1 |
tee
-a
$gemm_log
./profile_gemm.sh gemm 1 0
$verify
1 0 1 |
tee
-a
$gemm_log
./profile_gemm.sh gemm 2 0
$verify
1 0 1 |
tee
-a
$gemm_log
./profile_gemm.sh gemm 3 0
$verify
1 0 1 |
tee
-a
$gemm_log
./profile_gemm.sh gemm 0 1
$verify
1 0 1 |
tee
-a
$gemm_log
./profile_gemm.sh gemm 1 1
$verify
1 0 1 |
tee
-a
$gemm_log
./profile_gemm.sh gemm 2 1
$verify
1 0 1 |
tee
-a
$gemm_log
./profile_gemm.sh gemm 3 1
$verify
1 0 1 |
tee
-a
$gemm_log
./profile_gemm.sh gemm 0 2
$verify
1 0 1 |
tee
-a
$gemm_log
./profile_gemm.sh gemm 1 2
$verify
1 0 1 |
tee
-a
$gemm_log
./profile_gemm.sh gemm 2 2
$verify
1 0 1 |
tee
-a
$gemm_log
./profile_gemm.sh gemm 3 2
$verify
1 0 1 |
tee
-a
$gemm_log
./profile_gemm.sh gemm 0 3
$verify
1 0 1 |
tee
-a
$gemm_log
./profile_gemm.sh gemm 1 3
$verify
1 0 1 |
tee
-a
$gemm_log
./profile_gemm.sh gemm 2 3
$verify
1 0 1 |
tee
-a
$gemm_log
./profile_gemm.sh gemm 3 3
$verify
1 0 1 |
tee
-a
$gemm_log
./profile_gemm.sh gemm 0 0
$verify
1 0 1
2>&1
|
tee
-a
$gemm_log
./profile_gemm.sh gemm 1 0
$verify
1 0 1
2>&1
|
tee
-a
$gemm_log
./profile_gemm.sh gemm 2 0
$verify
1 0 1
2>&1
|
tee
-a
$gemm_log
./profile_gemm.sh gemm 3 0
$verify
1 0 1
2>&1
|
tee
-a
$gemm_log
./profile_gemm.sh gemm 0 1
$verify
1 0 1
2>&1
|
tee
-a
$gemm_log
./profile_gemm.sh gemm 1 1
$verify
1 0 1
2>&1
|
tee
-a
$gemm_log
./profile_gemm.sh gemm 2 1
$verify
1 0 1
2>&1
|
tee
-a
$gemm_log
./profile_gemm.sh gemm 3 1
$verify
1 0 1
2>&1
|
tee
-a
$gemm_log
./profile_gemm.sh gemm 0 2
$verify
1 0 1
2>&1
|
tee
-a
$gemm_log
./profile_gemm.sh gemm 1 2
$verify
1 0 1
2>&1
|
tee
-a
$gemm_log
./profile_gemm.sh gemm 2 2
$verify
1 0 1
2>&1
|
tee
-a
$gemm_log
./profile_gemm.sh gemm 3 2
$verify
1 0 1
2>&1
|
tee
-a
$gemm_log
./profile_gemm.sh gemm 0 3
$verify
1 0 1
2>&1
|
tee
-a
$gemm_log
./profile_gemm.sh gemm 1 3
$verify
1 0 1
2>&1
|
tee
-a
$gemm_log
./profile_gemm.sh gemm 2 3
$verify
1 0 1
2>&1
|
tee
-a
$gemm_log
./profile_gemm.sh gemm 3 3
$verify
1 0 1
2>&1
|
tee
-a
$gemm_log
#run batched_gemm tests
export
batched_gemm_log
=
"perf_batched_gemm_
${
gpu_arch
}
.log"
print_log_header
$batched_gemm_log
$env_type
$branch
$host_name
./profile_batched_gemm.sh batched_gemm 0 0
$verify
1 0 1 |
tee
-a
$batched_gemm_log
./profile_batched_gemm.sh batched_gemm 0 1
$verify
1 0 1 |
tee
-a
$batched_gemm_log
./profile_batched_gemm.sh batched_gemm 0 2
$verify
1 0 1 |
tee
-a
$batched_gemm_log
./profile_batched_gemm.sh batched_gemm 0 3
$verify
1 0 1 |
tee
-a
$batched_gemm_log
./profile_batched_gemm.sh batched_gemm 1 0
$verify
1 0 1 |
tee
-a
$batched_gemm_log
./profile_batched_gemm.sh batched_gemm 1 1
$verify
1 0 1 |
tee
-a
$batched_gemm_log
./profile_batched_gemm.sh batched_gemm 1 2
$verify
1 0 1 |
tee
-a
$batched_gemm_log
./profile_batched_gemm.sh batched_gemm 1 3
$verify
1 0 1 |
tee
-a
$batched_gemm_log
./profile_batched_gemm.sh batched_gemm 2 0
$verify
1 0 1 |
tee
-a
$batched_gemm_log
./profile_batched_gemm.sh batched_gemm 2 1
$verify
1 0 1 |
tee
-a
$batched_gemm_log
./profile_batched_gemm.sh batched_gemm 2 2
$verify
1 0 1 |
tee
-a
$batched_gemm_log
./profile_batched_gemm.sh batched_gemm 2 3
$verify
1 0 1 |
tee
-a
$batched_gemm_log
./profile_batched_gemm.sh batched_gemm 3 0
$verify
1 0 1 |
tee
-a
$batched_gemm_log
./profile_batched_gemm.sh batched_gemm 3 1
$verify
1 0 1 |
tee
-a
$batched_gemm_log
./profile_batched_gemm.sh batched_gemm 3 2
$verify
1 0 1 |
tee
-a
$batched_gemm_log
./profile_batched_gemm.sh batched_gemm 3 3
$verify
1 0 1 |
tee
-a
$batched_gemm_log
./profile_batched_gemm.sh batched_gemm 0 0
$verify
1 0 1
2>&1
|
tee
-a
$batched_gemm_log
./profile_batched_gemm.sh batched_gemm 0 1
$verify
1 0 1
2>&1
|
tee
-a
$batched_gemm_log
./profile_batched_gemm.sh batched_gemm 0 2
$verify
1 0 1
2>&1
|
tee
-a
$batched_gemm_log
./profile_batched_gemm.sh batched_gemm 0 3
$verify
1 0 1
2>&1
|
tee
-a
$batched_gemm_log
./profile_batched_gemm.sh batched_gemm 1 0
$verify
1 0 1
2>&1
|
tee
-a
$batched_gemm_log
./profile_batched_gemm.sh batched_gemm 1 1
$verify
1 0 1
2>&1
|
tee
-a
$batched_gemm_log
./profile_batched_gemm.sh batched_gemm 1 2
$verify
1 0 1
2>&1
|
tee
-a
$batched_gemm_log
./profile_batched_gemm.sh batched_gemm 1 3
$verify
1 0 1
2>&1
|
tee
-a
$batched_gemm_log
./profile_batched_gemm.sh batched_gemm 2 0
$verify
1 0 1
2>&1
|
tee
-a
$batched_gemm_log
./profile_batched_gemm.sh batched_gemm 2 1
$verify
1 0 1
2>&1
|
tee
-a
$batched_gemm_log
./profile_batched_gemm.sh batched_gemm 2 2
$verify
1 0 1
2>&1
|
tee
-a
$batched_gemm_log
./profile_batched_gemm.sh batched_gemm 2 3
$verify
1 0 1
2>&1
|
tee
-a
$batched_gemm_log
./profile_batched_gemm.sh batched_gemm 3 0
$verify
1 0 1
2>&1
|
tee
-a
$batched_gemm_log
./profile_batched_gemm.sh batched_gemm 3 1
$verify
1 0 1
2>&1
|
tee
-a
$batched_gemm_log
./profile_batched_gemm.sh batched_gemm 3 2
$verify
1 0 1
2>&1
|
tee
-a
$batched_gemm_log
./profile_batched_gemm.sh batched_gemm 3 3
$verify
1 0 1
2>&1
|
tee
-a
$batched_gemm_log
#run grouped_gemm tests
export
grouped_gemm_log
=
"perf_grouped_gemm_
${
gpu_arch
}
.log"
print_log_header
$grouped_gemm_log
$env_type
$branch
$host_name
./profile_grouped_gemm.sh grouped_gemm 1 0
$verify
1 0 1 |
tee
-a
$grouped_gemm_log
./profile_grouped_gemm.sh grouped_gemm 1 1
$verify
1 0 1 |
tee
-a
$grouped_gemm_log
./profile_grouped_gemm.sh grouped_gemm 1 2
$verify
1 0 1 |
tee
-a
$grouped_gemm_log
./profile_grouped_gemm.sh grouped_gemm 1 3
$verify
1 0 1 |
tee
-a
$grouped_gemm_log
./profile_grouped_gemm.sh grouped_gemm 1 0
$verify
1 0 1
2>&1
|
tee
-a
$grouped_gemm_log
./profile_grouped_gemm.sh grouped_gemm 1 1
$verify
1 0 1
2>&1
|
tee
-a
$grouped_gemm_log
./profile_grouped_gemm.sh grouped_gemm 1 2
$verify
1 0 1
2>&1
|
tee
-a
$grouped_gemm_log
./profile_grouped_gemm.sh grouped_gemm 1 3
$verify
1 0 1
2>&1
|
tee
-a
$grouped_gemm_log
#run GEMM+Bilinear tests
export
gemm_bilinear_log
=
"perf_gemm_bilinear_
${
gpu_arch
}
.log"
print_log_header
$gemm_bilinear_log
$env_type
$branch
$host_name
./profile_gemm_bilinear.sh gemm_bilinear 1 0
$verify
1 0 1 |
tee
-a
$gemm_bilinear_log
./profile_gemm_bilinear.sh gemm_bilinear 1 1
$verify
1 0 1 |
tee
-a
$gemm_bilinear_log
./profile_gemm_bilinear.sh gemm_bilinear 1 2
$verify
1 0 1 |
tee
-a
$gemm_bilinear_log
./profile_gemm_bilinear.sh gemm_bilinear 1 3
$verify
1 0 1 |
tee
-a
$gemm_bilinear_log
./profile_gemm_bilinear.sh gemm_bilinear 1 0
$verify
1 0 1
2>&1
|
tee
-a
$gemm_bilinear_log
./profile_gemm_bilinear.sh gemm_bilinear 1 1
$verify
1 0 1
2>&1
|
tee
-a
$gemm_bilinear_log
./profile_gemm_bilinear.sh gemm_bilinear 1 2
$verify
1 0 1
2>&1
|
tee
-a
$gemm_bilinear_log
./profile_gemm_bilinear.sh gemm_bilinear 1 3
$verify
1 0 1
2>&1
|
tee
-a
$gemm_bilinear_log
#run conv_fwd tests
export
conv_fwd_log
=
"perf_conv_fwd_
${
gpu_arch
}
.log"
print_log_header
$conv_fwd_log
$env_type
$branch
$host_name
./profile_conv_fwd.sh conv_fwd 0 1
$verify
1 0 1 256 |
tee
-a
$conv_fwd_log
./profile_conv_fwd.sh conv_fwd 1 1
$verify
1 0 1 256 |
tee
-a
$conv_fwd_log
./profile_conv_fwd.sh conv_fwd 2 1
$verify
1 0 1 256 |
tee
-a
$conv_fwd_log
./profile_conv_fwd.sh conv_fwd 3 1
$verify
1 0 1 256 |
tee
-a
$conv_fwd_log
./profile_conv_fwd.sh conv_fwd 0 1
$verify
1 0 1 256
2>&1
|
tee
-a
$conv_fwd_log
./profile_conv_fwd.sh conv_fwd 1 1
$verify
1 0 1 256
2>&1
|
tee
-a
$conv_fwd_log
./profile_conv_fwd.sh conv_fwd 2 1
$verify
1 0 1 256
2>&1
|
tee
-a
$conv_fwd_log
./profile_conv_fwd.sh conv_fwd 3 1
$verify
1 0 1 256
2>&1
|
tee
-a
$conv_fwd_log
#run conv_bwd_data tests
export
conv_bwd_data_log
=
"perf_conv_bwd_data_
${
gpu_arch
}
.log"
print_log_header
$conv_bwd_data_log
$env_type
$branch
$host_name
./profile_conv_bwd_data.sh conv_bwd_data 0 1
$verify
1 0 1 256 |
tee
-a
$conv_bwd_data_log
./profile_conv_bwd_data.sh conv_bwd_data 1 1
$verify
1 0 1 256 |
tee
-a
$conv_bwd_data_log
./profile_conv_bwd_data.sh conv_bwd_data 2 1
$verify
1 0 1 256 |
tee
-a
$conv_bwd_data_log
./profile_conv_bwd_data.sh conv_bwd_data 3 1
$verify
1 0 1 256 |
tee
-a
$conv_bwd_data_log
./profile_conv_bwd_data.sh conv_bwd_data 0 1
$verify
1 0 1 256
2>&1
|
tee
-a
$conv_bwd_data_log
./profile_conv_bwd_data.sh conv_bwd_data 1 1
$verify
1 0 1 256
2>&1
|
tee
-a
$conv_bwd_data_log
./profile_conv_bwd_data.sh conv_bwd_data 2 1
$verify
1 0 1 256
2>&1
|
tee
-a
$conv_bwd_data_log
./profile_conv_bwd_data.sh conv_bwd_data 3 1
$verify
1 0 1 256
2>&1
|
tee
-a
$conv_bwd_data_log
#run resnet50 tests
export
resnet256_log
=
"perf_resnet50_N256_
${
gpu_arch
}
.log"
print_log_header
$resnet256_log
$env_type
$branch
$host_name
./profile_resnet50.sh conv_fwd_bias_relu 1 1 1 1
$verify
1 0 1 256 |
tee
-a
$resnet256_log
./profile_resnet50.sh conv_fwd_bias_relu 1 1 1 1
$verify
1 0 1 256
2>&1
|
tee
-a
$resnet256_log
export
resnet4_log
=
"perf_resnet50_N4_
${
gpu_arch
}
.log"
print_log_header
$resnet4_log
$env_type
$branch
$host_name
./profile_resnet50.sh conv_fwd_bias_relu 1 1 1 1
$verify
1 0 1 4 |
tee
-a
$resnet4_log
./profile_resnet50.sh conv_fwd_bias_relu 1 1 1 1
$verify
1 0 1 4
2>&1
|
tee
-a
$resnet4_log
#run reduction tests
export
reduction_log
=
"perf_reduction_
${
gpu_arch
}
.log"
print_log_header
$reduction_log
$env_type
$branch
$host_name
./profile_reduce_with_index.sh
$verify
2 10
--half
|
tee
-a
$reduction_log
./profile_reduce_no_index.sh
$verify
2 10
--half
|
tee
-a
$reduction_log
./profile_reduce_with_index.sh
$verify
2 10
--half
2>&1 |
tee
-a
$reduction_log
./profile_reduce_no_index.sh
$verify
2 10
--half
2>&1 |
tee
-a
$reduction_log
#run splitK_gemm tests
export
splitK_gemm_log
=
"perf_splitK_gemm_
${
gpu_arch
}
.log"
print_log_header
$splitK_gemm_log
$env_type
$branch
$host_name
./profile_splitK_gemm.sh gemm_splitk 0 0
$verify
1 0 1 4 2>&1 |
tee
-a
$splitK_gemm_log
./profile_splitK_gemm.sh gemm_splitk 0 1
$verify
1 0 1 4 2>&1 |
tee
-a
$splitK_gemm_log
./profile_splitK_gemm.sh gemm_splitk 0 2
$verify
1 0 1 4 2>&1 |
tee
-a
$splitK_gemm_log
./profile_splitK_gemm.sh gemm_splitk 0 3
$verify
1 0 1 4 2>&1 |
tee
-a
$splitK_gemm_log
./profile_splitK_gemm.sh gemm_splitk 1 0
$verify
1 0 1 4 2>&1 |
tee
-a
$splitK_gemm_log
./profile_splitK_gemm.sh gemm_splitk 1 1
$verify
1 0 1 4 2>&1 |
tee
-a
$splitK_gemm_log
./profile_splitK_gemm.sh gemm_splitk 1 2
$verify
1 0 1 4 2>&1 |
tee
-a
$splitK_gemm_log
./profile_splitK_gemm.sh gemm_splitk 1 3
$verify
1 0 1 4 2>&1 |
tee
-a
$splitK_gemm_log
#run ONNX gemm tests
export
onnx_log
=
"perf_onnx_gemm_
${
gpu_arch
}
.log"
print_log_header
$onnx_log
$env_type
$branch
$host_name
./profile_onnx_gemm.sh gemm 0 0
$verify
1 0 1 2>&1 |
tee
-a
$onnx_log
./profile_onnx_gemm.sh gemm 1 0
$verify
1 0 1 2>&1 |
tee
-a
$onnx_log
test/CMakeLists.txt
View file @
b89a88b5
...
...
@@ -51,3 +51,4 @@ add_subdirectory(grouped_convnd_fwd)
add_subdirectory
(
block_to_ctile_map
)
add_subdirectory
(
softmax
)
add_subdirectory
(
layernorm
)
add_subdirectory
(
data_type
)
test/batched_gemm_gemm/test_batched_gemm_gemm_fp16.cpp
View file @
b89a88b5
...
...
@@ -11,7 +11,8 @@ class TestBatchedGemmGemmFP16 : public TestBatchedGemmGemm<Tuple>
// clang-format off
using
KernelTypes
=
::
testing
::
Types
<
std
::
tuple
<
F16
,
F16
,
F16
,
F16
,
Row
,
Col
,
Row
,
Row
>
std
::
tuple
<
F16
,
F16
,
F16
,
F16
,
Row
,
Col
,
Row
,
Row
>
,
std
::
tuple
<
F16
,
F16
,
F16
,
F16
,
Row
,
Col
,
Col
,
Row
>
>
;
// clang-format on
...
...
@@ -19,6 +20,73 @@ TYPED_TEST_SUITE(TestBatchedGemmGemmFP16, KernelTypes);
TYPED_TEST
(
TestBatchedGemmGemmFP16
,
Test_FP16
)
{
this
->
Run
();
}
TYPED_TEST
(
TestBatchedGemmGemmFP16
,
Test_FP16_PadM
)
{
this
->
lengths_
=
std
::
vector
<
std
::
vector
<
int
>>
{
{
136
,
128
,
32
,
128
,
1
},
};
this
->
Run
();
}
TYPED_TEST
(
TestBatchedGemmGemmFP16
,
Test_FP16_PadN
)
{
this
->
lengths_
=
std
::
vector
<
std
::
vector
<
int
>>
{
{
128
,
136
,
32
,
128
,
1
},
};
this
->
Run
();
}
TYPED_TEST
(
TestBatchedGemmGemmFP16
,
Test_FP16_PadK
)
{
this
->
lengths_
=
std
::
vector
<
std
::
vector
<
int
>>
{
{
128
,
128
,
40
,
128
,
1
},
{
128
,
128
,
136
,
128
,
1
},
};
this
->
Run
();
}
TYPED_TEST
(
TestBatchedGemmGemmFP16
,
Test_FP16_PadO
)
{
this
->
lengths_
=
std
::
vector
<
std
::
vector
<
int
>>
{
{
128
,
128
,
32
,
136
,
1
},
};
this
->
Run
();
}
TYPED_TEST
(
TestBatchedGemmGemmFP16
,
Test_FP16_OddM
)
{
this
->
lengths_
=
std
::
vector
<
std
::
vector
<
int
>>
{
{
129
,
128
,
32
,
128
,
1
},
};
this
->
Run
();
}
TYPED_TEST
(
TestBatchedGemmGemmFP16
,
Test_FP16_OddN
)
{
this
->
lengths_
=
std
::
vector
<
std
::
vector
<
int
>>
{
{
128
,
129
,
32
,
128
,
1
},
};
this
->
Run
();
}
TYPED_TEST
(
TestBatchedGemmGemmFP16
,
Test_FP16_OddK
)
{
this
->
lengths_
=
std
::
vector
<
std
::
vector
<
int
>>
{
{
128
,
128
,
33
,
128
,
1
},
{
128
,
128
,
129
,
128
,
1
},
};
this
->
Run
();
}
// If kernel B1Layout is RowMajor, expect not to support odd O size
TYPED_TEST
(
TestBatchedGemmGemmFP16
,
Test_FP16_OddO
)
{
this
->
lengths_
=
std
::
vector
<
std
::
vector
<
int
>>
{
{
128
,
128
,
32
,
129
,
1
},
};
this
->
Run
();
}
TYPED_TEST
(
TestBatchedGemmGemmFP16
,
DISABLED_Bench_FP16
)
{
this
->
lengths_
=
std
::
vector
<
std
::
vector
<
int
>>
{
...
...
@@ -37,3 +105,45 @@ TYPED_TEST(TestBatchedGemmGemmFP16, DISABLED_Bench_FP16)
this
->
verify_
=
false
;
this
->
Run
();
}
using
ck
::
tensor_operation
::
device
::
GemmSpecialization
;
TEST
(
TestBatchedGemmGemmInterface
,
GemmSpecializationSizeMatch
)
{
int
P
=
120
;
// requires padding
int
Q
=
128
;
// do not require padding
// IsSupported(M, N, K, O)
// clang-format off
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
Default
>
{}.
IsSupported
(
Q
,
Q
,
Q
,
Q
));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MPadding
>
{}.
IsSupported
(
P
,
Q
,
Q
,
Q
));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
NPadding
>
{}.
IsSupported
(
Q
,
P
,
Q
,
Q
));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
KPadding
>
{}.
IsSupported
(
Q
,
Q
,
P
,
Q
));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MNPadding
>
{}.
IsSupported
(
P
,
P
,
Q
,
Q
));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MKPadding
>
{}.
IsSupported
(
P
,
Q
,
P
,
Q
));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
NKPadding
>
{}.
IsSupported
(
Q
,
P
,
P
,
Q
));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MNKPadding
>
{}.
IsSupported
(
P
,
P
,
P
,
Q
));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
OPadding
>
{}.
IsSupported
(
Q
,
Q
,
Q
,
P
));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MOPadding
>
{}.
IsSupported
(
P
,
Q
,
Q
,
P
));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
NOPadding
>
{}.
IsSupported
(
Q
,
P
,
Q
,
P
));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
KOPadding
>
{}.
IsSupported
(
Q
,
Q
,
P
,
P
));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MNOPadding
>
{}.
IsSupported
(
P
,
P
,
Q
,
P
));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MKOPadding
>
{}.
IsSupported
(
P
,
Q
,
P
,
P
));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
NKOPadding
>
{}.
IsSupported
(
Q
,
P
,
P
,
P
));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MNKOPadding
>
{}.
IsSupported
(
P
,
P
,
P
,
P
));
// clang-format on
}
TEST
(
TestBatchedGemmGemmInterface
,
GemmSpecializationSizeMismatch
)
{
// IsSupported(M, N, K, O)
// clang-format off
EXPECT_FALSE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
Default
>
{}.
IsSupported
(
128
,
128
,
120
,
128
));
EXPECT_FALSE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MNKPadding
>
{}.
IsSupported
(
128
,
128
,
128
,
120
));
// Kernel can't support odd K size because SrcVectorDim == KDim and must satisfy SizeKRaw % ABSrcScalarPerVector == 0
EXPECT_FALSE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MNKOPadding
>
{}.
IsSupported
(
128
,
128
,
129
,
128
));
EXPECT_FALSE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MNKOPadding
>
{}.
IsSupported
(
128
,
128
,
130
,
128
));
// Kernel can't support odd O size because SrcVectorDim == ODim and must satisfy SizeORaw % B1SrcScalarPerVector == 0
EXPECT_FALSE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MNKOPadding
>
{}.
IsSupported
(
128
,
128
,
128
,
129
));
// clang-format on
}
test/batched_gemm_gemm/test_batched_gemm_gemm_util.hpp
View file @
b89a88b5
...
...
@@ -4,8 +4,12 @@
#include <iostream>
#include <vector>
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp"
#include "profiler/include/profile_batched_gemm_gemm_impl.hpp"
using
ck
::
tensor_operation
::
device
::
GemmSpecialization
;
template
<
ck
::
index_t
N
>
using
I
=
ck
::
Number
<
N
>
;
...
...
@@ -66,3 +70,120 @@ struct TestBatchedGemmGemm : public ::testing::Test
}
}
};
template
<
GemmSpecialization
GemmSpec
>
struct
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
{
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
ALayout
=
Row
;
using
B0Layout
=
Col
;
using
B1Layout
=
Row
;
using
CLayout
=
Row
;
using
ADataType
=
F16
;
using
B0DataType
=
F16
;
using
B1DataType
=
F16
;
using
AccDataType
=
float
;
using
CShuffleDataType
=
float
;
using
CDataType
=
F16
;
using
AElementOp
=
PassThrough
;
using
B0ElementOp
=
PassThrough
;
using
Acc0ElementOp
=
PassThrough
;
using
B1ElementOp
=
PassThrough
;
using
CElementOp
=
PassThrough
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
// static constexpr auto GemmSpec = std::tuple_element_t<0, Tuple>::value;
using
DeviceGemmGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceBatchedGemmGemm_Xdl_CShuffle
<
ALayout
,
B0Layout
,
B1Layout
,
CLayout
,
ADataType
,
B0DataType
,
B1DataType
,
CDataType
,
AccDataType
,
CShuffleDataType
,
AElementOp
,
B0ElementOp
,
Acc0ElementOp
,
B1ElementOp
,
CElementOp
,
GemmSpec
,
1
,
256
,
128
,
// MPerBlock
128
,
// NPerBlock
32
,
// KPerBlock
128
,
// Gemm1NPerBlock
32
,
// Gemm1KPerBlock
8
,
// AK1
8
,
// BK1
2
,
// B1K1
32
,
// MPerXDL
32
,
// NPerXDL
1
,
// MXdlPerWave
4
,
// NXdlPerWave
4
,
// Gemm1NXdlPerWave
S
<
4
,
64
,
1
>
,
// ABlockTransfer
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
// BBlockTransfer
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
// B1BlockTransfer
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
// CShuffleMXdlPerWavePerShuffle
2
,
// CShuffleNXdlPerWavePerShuffle
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
>
;
// CShuffleBlockTransferScalarPerVector_NPerBlock
bool
IsSupported
(
int
M
,
int
N
,
int
K
,
int
O
)
{
auto
gemm
=
DeviceGemmGemmInstance
{};
auto
invoker
=
gemm
.
MakeInvoker
();
auto
argument
=
gemm
.
MakeArgument
(
static_cast
<
ADataType
*>
(
nullptr
),
static_cast
<
B0DataType
*>
(
nullptr
),
static_cast
<
B1DataType
*>
(
nullptr
),
static_cast
<
CDataType
*>
(
nullptr
),
M
,
N
,
K
,
O
,
0
,
// BatchCount
0
,
// StrideA
0
,
// StrideB0
0
,
// StrideB1
0
,
// StrideC
0
,
// BatchStrideA
0
,
// BatchStrideB0
0
,
// BatchStrideB1
0
,
// BatchStrideC
PassThrough
{},
// a_element_op
PassThrough
{},
// b0_element_op
PassThrough
{},
// acc0_element_op
PassThrough
{},
// b1_element_op
PassThrough
{});
// c_element_op
return
gemm
.
IsSupportedArgument
(
argument
);
}
};
test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_fp16.cpp
View file @
b89a88b5
...
...
@@ -19,6 +19,73 @@ TYPED_TEST_SUITE(TestBatchedGemmSoftmaxGemmFP16, KernelTypes);
TYPED_TEST
(
TestBatchedGemmSoftmaxGemmFP16
,
Test_FP16
)
{
this
->
Run
();
}
TYPED_TEST
(
TestBatchedGemmSoftmaxGemmFP16
,
Test_FP16_PadM
)
{
this
->
lengths_
=
std
::
vector
<
std
::
vector
<
int
>>
{
{
136
,
128
,
32
,
128
,
1
},
};
this
->
Run
();
}
TYPED_TEST
(
TestBatchedGemmSoftmaxGemmFP16
,
Test_FP16_PadN
)
{
this
->
lengths_
=
std
::
vector
<
std
::
vector
<
int
>>
{
{
128
,
136
,
32
,
128
,
1
},
};
this
->
Run
();
}
TYPED_TEST
(
TestBatchedGemmSoftmaxGemmFP16
,
Test_FP16_PadK
)
{
this
->
lengths_
=
std
::
vector
<
std
::
vector
<
int
>>
{
{
128
,
128
,
40
,
128
,
1
},
{
128
,
128
,
136
,
128
,
1
},
};
this
->
Run
();
}
TYPED_TEST
(
TestBatchedGemmSoftmaxGemmFP16
,
Test_FP16_PadO
)
{
this
->
lengths_
=
std
::
vector
<
std
::
vector
<
int
>>
{
{
128
,
128
,
32
,
136
,
1
},
};
this
->
Run
();
}
TYPED_TEST
(
TestBatchedGemmSoftmaxGemmFP16
,
Test_FP16_OddM
)
{
this
->
lengths_
=
std
::
vector
<
std
::
vector
<
int
>>
{
{
129
,
128
,
32
,
128
,
1
},
};
this
->
Run
();
}
TYPED_TEST
(
TestBatchedGemmSoftmaxGemmFP16
,
Test_FP16_OddN
)
{
this
->
lengths_
=
std
::
vector
<
std
::
vector
<
int
>>
{
{
128
,
129
,
32
,
128
,
1
},
};
this
->
Run
();
}
TYPED_TEST
(
TestBatchedGemmSoftmaxGemmFP16
,
Test_FP16_OddK
)
{
this
->
lengths_
=
std
::
vector
<
std
::
vector
<
int
>>
{
{
128
,
128
,
33
,
128
,
1
},
{
128
,
128
,
129
,
128
,
1
},
};
this
->
Run
();
}
// If kernel B1Layout is RowMajor, expect not to support odd O size
TYPED_TEST
(
TestBatchedGemmSoftmaxGemmFP16
,
Test_FP16_OddO
)
{
this
->
lengths_
=
std
::
vector
<
std
::
vector
<
int
>>
{
{
128
,
128
,
32
,
129
,
1
},
};
this
->
Run
();
}
TYPED_TEST
(
TestBatchedGemmSoftmaxGemmFP16
,
DISABLED_Bench_FP16
)
{
this
->
lengths_
=
std
::
vector
<
std
::
vector
<
int
>>
{
...
...
@@ -37,3 +104,58 @@ TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, DISABLED_Bench_FP16)
this
->
verify_
=
false
;
this
->
Run
();
}
using
ck
::
tensor_operation
::
device
::
GemmSpecialization
;
// TODO: enable KPadding tests when it is implemented
TEST
(
TestBatchedGemmSoftmaxGemmInterface
,
GemmSpecializationSizeMatch
)
{
int
P
=
120
;
// requires padding
int
Q
=
128
;
// do not require padding
// IsSupported(M, N, K, O)
// clang-format off
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
Default
>
{}.
IsSupported
(
Q
,
Q
,
Q
,
Q
));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MPadding
>
{}.
IsSupported
(
P
,
Q
,
Q
,
Q
));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
NPadding
>
{}.
IsSupported
(
Q
,
P
,
Q
,
Q
));
// EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::KPadding>{}.IsSupported(Q, Q, P, Q));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MNPadding
>
{}.
IsSupported
(
P
,
P
,
Q
,
Q
));
// EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MKPadding>{}.IsSupported(P, Q, P, Q));
// EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::NKPadding>{}.IsSupported(Q, P, P, Q));
// EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKPadding>{}.IsSupported(P, P, P, Q));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
OPadding
>
{}.
IsSupported
(
Q
,
Q
,
Q
,
P
));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MOPadding
>
{}.
IsSupported
(
P
,
Q
,
Q
,
P
));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
NOPadding
>
{}.
IsSupported
(
Q
,
P
,
Q
,
P
));
// EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::KOPadding>{}.IsSupported(Q, Q, P, P));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MNOPadding
>
{}.
IsSupported
(
P
,
P
,
Q
,
P
));
// EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MKOPadding>{}.IsSupported(P, Q, P, P));
// EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::NKOPadding>{}.IsSupported(Q, P, P, P));
// EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(P, P, P, P));
// clang-format on
}
TEST
(
TestBatchedGemmSoftmaxGemmInterface
,
GemmSpecializationSizeMismatch
)
{
// IsSupported(M, N, K, O)
// clang-format off
EXPECT_FALSE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
Default
>
{}.
IsSupported
(
128
,
128
,
120
,
128
));
// EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKPadding>{}.IsSupported(128, 128, 128, 120));
// Kernel can't support odd K size because SrcVectorDim == KDim and must satisfy SizeKRaw % ABSrcScalarPerVector == 0
// EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 129, 128));
// EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 130, 128));
// Kernel can't support odd O size because SrcVectorDim == ODim and must satisfy SizeORaw % B1SrcScalarPerVector == 0
// EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 128, 129));
// clang-format on
}
TYPED_TEST
(
TestBatchedGemmSoftmaxGemmFP16
,
AdhocTest
)
{
this
->
lengths_
=
std
::
vector
<
std
::
vector
<
int
>>
{
{
49
,
49
,
64
,
64
,
24
},
{
64
,
49
,
64
,
64
,
24
},
{
1020
,
1020
,
64
,
128
,
24
},
{
576
,
576
,
64
,
64
,
24
},
};
this
->
bench_
=
true
;
this
->
Run
();
}
test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_util.hpp
View file @
b89a88b5
...
...
@@ -4,7 +4,10 @@
#include <iostream>
#include <vector>
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp"
#include "profiler/include/profile_batched_gemm_softmax_gemm_impl.hpp"
using
ck
::
tensor_operation
::
device
::
GemmSpecialization
;
template
<
ck
::
index_t
N
>
using
I
=
ck
::
Number
<
N
>
;
...
...
@@ -66,3 +69,121 @@ struct TestBatchedGemmSoftmaxGemm : public ::testing::Test
}
}
};
template
<
GemmSpecialization
GemmSpec
>
struct
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
{
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
ALayout
=
Row
;
using
B0Layout
=
Col
;
using
B1Layout
=
Row
;
using
CLayout
=
Row
;
using
ADataType
=
F16
;
using
B0DataType
=
F16
;
using
B1DataType
=
F16
;
using
AccDataType
=
float
;
using
CShuffleDataType
=
float
;
using
CDataType
=
F16
;
using
AElementOp
=
PassThrough
;
using
B0ElementOp
=
PassThrough
;
using
Acc0ElementOp
=
PassThrough
;
using
B1ElementOp
=
PassThrough
;
using
CElementOp
=
PassThrough
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
// static constexpr auto GemmSpec = std::tuple_element_t<0, Tuple>::value;
using
DeviceGemmGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
<
ALayout
,
B0Layout
,
B1Layout
,
CLayout
,
ADataType
,
B0DataType
,
B1DataType
,
CDataType
,
AccDataType
,
CShuffleDataType
,
AElementOp
,
B0ElementOp
,
Acc0ElementOp
,
B1ElementOp
,
CElementOp
,
GemmSpec
,
1
,
256
,
128
,
// MPerBlock
128
,
// NPerBlock
32
,
// KPerBlock
128
,
// Gemm1NPerBlock
32
,
// Gemm1KPerBlock
8
,
// AK1
8
,
// BK1
2
,
// B1K1
32
,
// MPerXDL
32
,
// NPerXDL
1
,
// MXdlPerWave
4
,
// NXdlPerWave
4
,
// Gemm1NXdlPerWave
S
<
4
,
64
,
1
>
,
// ABlockTransfer
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
// BBlockTransfer
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
// B1BlockTransfer
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
// CShuffleMXdlPerWavePerShuffle
2
,
// CShuffleNXdlPerWavePerShuffle
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
>
;
// CShuffleBlockTransferScalarPerVector_NPerBlock
bool
IsSupported
(
int
M
,
int
N
,
int
K
,
int
O
)
{
auto
gemm
=
DeviceGemmGemmInstance
{};
auto
invoker
=
gemm
.
MakeInvoker
();
auto
argument
=
gemm
.
MakeArgument
(
static_cast
<
ADataType
*>
(
nullptr
),
static_cast
<
B0DataType
*>
(
nullptr
),
static_cast
<
B1DataType
*>
(
nullptr
),
static_cast
<
CDataType
*>
(
nullptr
),
M
,
N
,
K
,
O
,
0
,
// BatchCount
0
,
// StrideA
0
,
// StrideB0
0
,
// StrideB1
0
,
// StrideC
0
,
// BatchStrideA
0
,
// BatchStrideB0
0
,
// BatchStrideB1
0
,
// BatchStrideC
PassThrough
{},
// a_element_op
PassThrough
{},
// b0_element_op
PassThrough
{},
// acc0_element_op
PassThrough
{},
// b1_element_op
PassThrough
{});
// c_element_op
return
gemm
.
IsSupportedArgument
(
argument
);
}
};
test/data_type/CMakeLists.txt
0 → 100644
View file @
b89a88b5
if
(
USE_BITINT_EXTENSION_INT4
)
add_gtest_executable
(
test_int4 int4.cpp
)
target_link_libraries
(
test_int4 PRIVATE utility
)
endif
()
test/data_type/int4.cpp
0 → 100644
View file @
b89a88b5
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <bitset>
#include <cinttypes>
#include <cstdint>
#include <iomanip>
#include "gtest/gtest.h"
#include <hip/hip_runtime.h>
#include "ck/host_utility/hip_check_error.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/math_v2.hpp"
#include "ck/utility/get_id.hpp"
#include "ck/library/utility/device_memory.hpp"
using
ck
::
int4_t
;
TEST
(
Int4
,
BaseArithmetic
)
{
int4_t
a
{
1
};
int4_t
b
{
-
2
};
EXPECT_EQ
(
a
+
a
,
int4_t
{
2
});
EXPECT_EQ
(
a
-
a
,
int4_t
{
0
});
EXPECT_EQ
(
a
+
b
,
int4_t
{
-
1
});
EXPECT_EQ
(
a
-
b
,
int4_t
{
3
});
EXPECT_EQ
(
a
*
a
,
int4_t
{
1
});
EXPECT_EQ
(
a
*
b
,
int4_t
{
-
2
});
EXPECT_EQ
(
b
*
b
,
int4_t
{
4
});
EXPECT_EQ
(
a
/
b
,
int4_t
{
0
});
a
=
int4_t
{
4
};
EXPECT_EQ
(
a
/
b
,
int4_t
{
-
2
});
b
=
int4_t
{
2
};
EXPECT_EQ
(
a
%
b
,
int4_t
{
0
});
}
TEST
(
Int4
,
NumericLimits
)
{
EXPECT_EQ
(
ck
::
NumericLimits
<
int4_t
>::
Min
(),
int4_t
{
-
8
});
EXPECT_EQ
(
ck
::
NumericLimits
<
int4_t
>::
Max
(),
int4_t
{
7
});
EXPECT_EQ
(
ck
::
NumericLimits
<
int4_t
>::
Lowest
(),
int4_t
{
-
8
});
}
TEST
(
Int4
,
MathOpsV2
)
{
int4_t
a
{
4
};
int4_t
b
{
-
5
};
EXPECT_EQ
(
ck
::
math
::
abs
(
a
),
int4_t
{
4
});
EXPECT_EQ
(
ck
::
math
::
abs
(
b
),
int4_t
{
5
});
EXPECT_FALSE
(
ck
::
math
::
isnan
(
b
));
}
namespace
{
__global__
void
copy
(
const
int4_t
*
src
,
std
::
int8_t
*
dst
,
ck
::
index_t
N
)
{
ck
::
index_t
tid
=
ck
::
get_thread_global_1d_id
();
const
int8_t
*
src_i8
=
reinterpret_cast
<
const
int8_t
*>
(
src
);
if
(
tid
<
N
)
{
for
(
ck
::
index_t
i
=
tid
;
i
<
N
;
i
+=
ck
::
get_grid_size
())
{
dst
[
i
]
=
src_i8
[
i
];
}
}
}
__global__
void
copy_with_static_cast
(
const
int4_t
*
src
,
std
::
int8_t
*
dst
,
ck
::
index_t
N
)
{
ck
::
index_t
tid
=
ck
::
get_thread_global_1d_id
();
if
(
tid
<
N
)
{
for
(
ck
::
index_t
i
=
tid
;
i
<
N
;
i
+=
ck
::
get_grid_size
())
{
dst
[
i
]
=
static_cast
<
std
::
int8_t
>
(
src
[
i
]);
}
}
}
}
// anonymous namespace
TEST
(
Int4
,
CopyAsI8PositiveValue
)
{
constexpr
std
::
size_t
SIZE
=
100
;
std
::
vector
<
int4_t
>
h_src_i4
(
SIZE
,
7
);
std
::
vector
<
std
::
int8_t
>
h_src_i8
(
SIZE
,
7
);
std
::
vector
<
std
::
int8_t
>
h_dst_i8
(
SIZE
,
0
);
DeviceMem
d_src_i4
(
h_src_i4
.
size
()
*
sizeof
(
int4_t
));
DeviceMem
d_dst_i8
(
h_dst_i8
.
size
()
*
sizeof
(
std
::
int8_t
));
d_src_i4
.
SetZero
();
d_dst_i8
.
SetZero
();
d_src_i4
.
ToDevice
(
h_src_i4
.
data
());
copy
<<<
1
,
64
>>>
(
reinterpret_cast
<
const
int4_t
*>
(
d_src_i4
.
GetDeviceBuffer
()),
reinterpret_cast
<
std
::
int8_t
*>
(
d_dst_i8
.
GetDeviceBuffer
()),
SIZE
);
hip_check_error
(
hipDeviceSynchronize
());
d_dst_i8
.
FromDevice
(
h_dst_i8
.
data
());
for
(
std
::
size_t
i
=
0
;
i
<
SIZE
;
++
i
)
{
EXPECT_EQ
(
h_src_i8
[
i
],
h_dst_i8
[
i
]);
}
}
TEST
(
Int4
,
DISABLED_CopyAsI8NegativeValue
)
{
constexpr
std
::
size_t
SIZE
=
32
;
std
::
vector
<
int4_t
>
h_src_i4
(
SIZE
,
-
8
);
std
::
vector
<
std
::
int8_t
>
h_src_i8
(
SIZE
,
-
8
);
std
::
vector
<
std
::
int8_t
>
h_dst_i8
(
SIZE
,
0
);
DeviceMem
d_src_i4
(
h_src_i4
.
size
()
*
sizeof
(
int4_t
));
DeviceMem
d_dst_i8
(
h_dst_i8
.
size
()
*
sizeof
(
std
::
int8_t
));
d_src_i4
.
SetZero
();
d_dst_i8
.
SetZero
();
d_src_i4
.
ToDevice
(
h_src_i4
.
data
());
copy
<<<
1
,
64
>>>
(
reinterpret_cast
<
const
int4_t
*>
(
d_src_i4
.
GetDeviceBuffer
()),
reinterpret_cast
<
std
::
int8_t
*>
(
d_dst_i8
.
GetDeviceBuffer
()),
SIZE
);
hip_check_error
(
hipDeviceSynchronize
());
d_dst_i8
.
FromDevice
(
h_dst_i8
.
data
());
for
(
std
::
size_t
i
=
0
;
i
<
SIZE
;
++
i
)
{
EXPECT_EQ
(
h_src_i8
[
i
],
h_dst_i8
[
i
]);
}
}
TEST
(
Int4
,
CopyAsI8NegativeValueStaticCast
)
{
constexpr
std
::
size_t
SIZE
=
32
;
std
::
vector
<
int4_t
>
h_src_i4
(
SIZE
,
-
8
);
std
::
vector
<
std
::
int8_t
>
h_src_i8
(
SIZE
,
-
8
);
std
::
vector
<
std
::
int8_t
>
h_dst_i8
(
SIZE
,
0
);
DeviceMem
d_src_i4
(
h_src_i4
.
size
()
*
sizeof
(
int4_t
));
DeviceMem
d_dst_i8
(
h_dst_i8
.
size
()
*
sizeof
(
std
::
int8_t
));
d_src_i4
.
SetZero
();
d_dst_i8
.
SetZero
();
d_src_i4
.
ToDevice
(
h_src_i4
.
data
());
copy_with_static_cast
<<<
1
,
64
>>>
(
reinterpret_cast
<
const
int4_t
*>
(
d_src_i4
.
GetDeviceBuffer
()),
reinterpret_cast
<
std
::
int8_t
*>
(
d_dst_i8
.
GetDeviceBuffer
()),
SIZE
);
hip_check_error
(
hipDeviceSynchronize
());
d_dst_i8
.
FromDevice
(
h_dst_i8
.
data
());
for
(
std
::
size_t
i
=
0
;
i
<
SIZE
;
++
i
)
{
EXPECT_EQ
(
h_src_i8
[
i
],
h_dst_i8
[
i
]);
}
}
TEST
(
Int4
,
DISABLED_BitwiseRepresentation
)
{
using
bit8_t
=
std
::
bitset
<
8
>
;
int4_t
a_i4
{
3
};
std
::
int8_t
a_i8
=
*
reinterpret_cast
<
std
::
int8_t
*>
(
&
a_i4
);
std
::
int8_t
b_i8
{
3
};
#if 0
std::cout << std::hex << std::showbase << static_cast<int32_t>(a_i8)
<< ", " << static_cast<int32_t>(b_i8) << std::endl;
#endif
EXPECT_EQ
(
bit8_t
{
static_cast
<
std
::
uint64_t
>
(
a_i8
)},
bit8_t
{
static_cast
<
std
::
uint64_t
>
(
b_i8
)});
a_i4
=
int4_t
{
-
3
};
a_i8
=
*
reinterpret_cast
<
std
::
int8_t
*>
(
&
a_i4
);
b_i8
=
std
::
int8_t
{
-
3
};
#if 0
std::cout << std::hex << std::showbase << static_cast<int32_t>(a_i8)
<< ", " << static_cast<int32_t>(b_i8) << std::endl;
#endif
EXPECT_EQ
(
bit8_t
{
static_cast
<
std
::
uint64_t
>
(
a_i8
)},
bit8_t
{
static_cast
<
std
::
uint64_t
>
(
b_i8
)});
}
TEST
(
Int4
,
BitwiseRepresentationStaticCast
)
{
using
bit8_t
=
std
::
bitset
<
8
>
;
int4_t
a_i4
{
3
};
std
::
int8_t
a_i8
=
static_cast
<
std
::
int8_t
>
(
a_i4
);
std
::
int8_t
b_i8
{
3
};
#if 0
std::cout << std::hex << std::showbase << static_cast<int32_t>(a_i8)
<< ", " << static_cast<int32_t>(b_i8) << std::endl;
#endif
EXPECT_EQ
(
bit8_t
{
static_cast
<
std
::
uint64_t
>
(
a_i8
)},
bit8_t
{
static_cast
<
std
::
uint64_t
>
(
b_i8
)});
a_i4
=
int4_t
{
-
3
};
a_i8
=
static_cast
<
std
::
int8_t
>
(
a_i4
);
b_i8
=
std
::
int8_t
{
-
3
};
#if 0
std::cout << std::hex << std::showbase << static_cast<int32_t>(a_i8)
<< ", " << static_cast<int32_t>(b_i8) << std::endl;
#endif
EXPECT_EQ
(
bit8_t
{
static_cast
<
std
::
uint64_t
>
(
a_i8
)},
bit8_t
{
static_cast
<
std
::
uint64_t
>
(
b_i8
)});
}
test/gemm_split_k/gemm_split_k.cpp
View file @
b89a88b5
...
...
@@ -8,7 +8,6 @@
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp"
...
...
Prev
1
…
9
10
11
12
13
14
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment