Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
bdddf1ea
Unverified
Commit
bdddf1ea
authored
Jan 18, 2025
by
Bartłomiej Kocot
Committed by
GitHub
Jan 18, 2025
Browse files
[CK_TILE] Add error threshold calculation for gemm examples (#1821)
parent
0fcbb25f
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
259 additions
and
14 deletions
+259
-14
example/ck_tile/03_gemm/run_gemm_example.inc
example/ck_tile/03_gemm/run_gemm_example.inc
+45
-6
example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc
example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc
+45
-4
example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc
example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc
+33
-2
include/ck_tile/core/numeric/bfloat16.hpp
include/ck_tile/core/numeric/bfloat16.hpp
+11
-1
include/ck_tile/host/check_err.hpp
include/ck_tile/host/check_err.hpp
+125
-1
No files found.
example/ck_tile/03_gemm/run_gemm_example.inc
View file @
bdddf1ea
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024
-2025
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
auto
calculate_rtol_atol
(
const
ck_tile
::
index_t
K
,
const
ck_tile
::
index_t
kbatch
,
const
float
max_accumulated_value
)
{
using
ComputeType
=
std
::
conditional_t
<
sizeof
(
ADataType
)
<
sizeof
(
BDataType
),
ADataType
,
BDataType
>
;
// Calculate thresholds
const
auto
rtol
=
ck_tile
::
get_relative_threshold
<
ComputeType
,
CDataType
,
AccDataType
>
(
ck_tile
::
integer_divide_ceil
(
K
,
kbatch
));
const
auto
atol
=
ck_tile
::
get_absolute_threshold
<
ComputeType
,
CDataType
,
AccDataType
>
(
max_accumulated_value
/
kbatch
,
ck_tile
::
integer_divide_ceil
(
K
,
kbatch
));
// Calculate error due to split_k accumulation
const
auto
rtol_split_k
=
ck_tile
::
get_relative_threshold
<
CDataType
,
CDataType
,
CDataType
>
(
kbatch
);
const
auto
atol_split_k
=
ck_tile
::
get_absolute_threshold
<
CDataType
,
CDataType
,
CDataType
>
(
max_accumulated_value
,
kbatch
);
// Use higher threshold
return
ck_tile
::
make_tuple
(
std
::
max
(
rtol
,
rtol_split_k
),
std
::
max
(
atol
,
atol_split_k
));
}
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
float
invoke_gemm
(
ck_tile
::
DeviceMem
&
a_m_k_dev_buf
,
float
invoke_gemm
(
ck_tile
::
DeviceMem
&
a_m_k_dev_buf
,
ck_tile
::
DeviceMem
&
b_k_n_dev_buf
,
ck_tile
::
DeviceMem
&
b_k_n_dev_buf
,
...
@@ -148,9 +168,18 @@ int run_gemm_example_with_layouts(int argc,
...
@@ -148,9 +168,18 @@ int run_gemm_example_with_layouts(int argc,
ck_tile
::
reference_gemm
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
>
(
ck_tile
::
reference_gemm
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
>
(
a_m_k
,
b_k_n
,
c_m_n_host_ref
);
a_m_k
,
b_k_n
,
c_m_n_host_ref
);
const
float
max_accumulated_value
=
pass
=
ck_tile
::
check_err
(
c_m_n_dev_result
,
c_m_n_host_ref
);
*
std
::
max_element
(
c_m_n_host_ref
.
mData
.
begin
(),
c_m_n_host_ref
.
mData
.
end
());
const
auto
rtol_atol
=
calculate_rtol_atol
(
K
,
kbatch
,
max_accumulated_value
);
pass
=
ck_tile
::
check_err
(
c_m_n_dev_result
,
c_m_n_host_ref
,
"Error: Incorrect results!"
,
rtol_atol
.
at
(
ck_tile
::
number
<
0
>
{}),
rtol_atol
.
at
(
ck_tile
::
number
<
1
>
{}));
std
::
cout
<<
"Relative error threshold: "
<<
rtol_atol
.
at
(
ck_tile
::
number
<
0
>
{})
<<
" Absolute error threshold: "
<<
rtol_atol
.
at
(
ck_tile
::
number
<
1
>
{})
<<
std
::
endl
;
std
::
cout
<<
"The CPU veification result is:"
<<
(
pass
?
"correct"
:
"fail"
)
<<
std
::
endl
;
std
::
cout
<<
"The CPU veification result is:"
<<
(
pass
?
"correct"
:
"fail"
)
<<
std
::
endl
;
}
}
else
if
(
arg_parser
.
get_int
(
"v"
)
==
2
)
else
if
(
arg_parser
.
get_int
(
"v"
)
==
2
)
...
@@ -196,8 +225,18 @@ int run_gemm_example_with_layouts(int argc,
...
@@ -196,8 +225,18 @@ int run_gemm_example_with_layouts(int argc,
ck_tile
::
hip_check_error
(
hipFree
(
d_C
));
ck_tile
::
hip_check_error
(
hipFree
(
d_C
));
c_m_n_gpu_buf_ref
.
FromDevice
(
c_m_n_gpu_ref
.
data
());
c_m_n_gpu_buf_ref
.
FromDevice
(
c_m_n_gpu_ref
.
data
());
pass
=
ck_tile
::
check_err
(
c_m_n_dev_result
,
c_m_n_gpu_ref
);
const
float
max_accumulated_value
=
*
std
::
max_element
(
c_m_n_gpu_ref
.
mData
.
begin
(),
c_m_n_gpu_ref
.
mData
.
end
());
const
auto
rtol_atol
=
calculate_rtol_atol
(
K
,
kbatch
,
max_accumulated_value
);
pass
=
ck_tile
::
check_err
(
c_m_n_dev_result
,
c_m_n_gpu_ref
,
"Error: Incorrect results!"
,
rtol_atol
.
at
(
ck_tile
::
number
<
0
>
{}),
rtol_atol
.
at
(
ck_tile
::
number
<
1
>
{}));
std
::
cout
<<
"Relative error threshold: "
<<
rtol_atol
.
at
(
ck_tile
::
number
<
0
>
{})
<<
" Absolute error threshold: "
<<
rtol_atol
.
at
(
ck_tile
::
number
<
1
>
{})
<<
std
::
endl
;
std
::
cout
<<
"The GPU veification result is: "
<<
(
pass
?
"correct"
:
"fail"
)
<<
std
::
endl
;
std
::
cout
<<
"The GPU veification result is: "
<<
(
pass
?
"correct"
:
"fail"
)
<<
std
::
endl
;
}
}
...
...
example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc
View file @
bdddf1ea
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024
-2025
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
auto
calculate_rtol_atol
(
const
ck_tile
::
index_t
K
,
const
ck_tile
::
index_t
kbatch
,
const
float
max_accumulated_value
)
{
using
ComputeType
=
std
::
conditional_t
<
sizeof
(
ADataType
)
<
sizeof
(
BDataType
),
ADataType
,
BDataType
>
;
// Calculate thresholds
const
auto
rtol
=
ck_tile
::
get_relative_threshold
<
ComputeType
,
CDataType
,
AccDataType
>
(
ck_tile
::
integer_divide_ceil
(
K
,
kbatch
));
const
auto
atol
=
ck_tile
::
get_absolute_threshold
<
ComputeType
,
CDataType
,
AccDataType
>
(
max_accumulated_value
/
kbatch
,
ck_tile
::
integer_divide_ceil
(
K
,
kbatch
));
// Calculate error due to split_k accumulation
const
auto
rtol_split_k
=
ck_tile
::
get_relative_threshold
<
CDataType
,
CDataType
,
CDataType
>
(
kbatch
);
const
auto
atol_split_k
=
ck_tile
::
get_absolute_threshold
<
CDataType
,
CDataType
,
CDataType
>
(
max_accumulated_value
,
kbatch
);
// Use higher threshold
return
ck_tile
::
make_tuple
(
std
::
max
(
rtol
,
rtol_split_k
),
std
::
max
(
atol
,
atol_split_k
));
}
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
float
invoke_batched_gemm
(
ck_tile
::
DeviceMem
&
a_m_k_dev_buf
,
float
invoke_batched_gemm
(
ck_tile
::
DeviceMem
&
a_m_k_dev_buf
,
ck_tile
::
DeviceMem
&
b_k_n_dev_buf
,
ck_tile
::
DeviceMem
&
b_k_n_dev_buf
,
...
@@ -179,8 +199,18 @@ int run_batched_gemm_example_with_layouts(int argc,
...
@@ -179,8 +199,18 @@ int run_batched_gemm_example_with_layouts(int argc,
ck_tile
::
reference_batched_gemm
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
>
(
ck_tile
::
reference_batched_gemm
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
>
(
a_m_k
,
b_n_k
,
c_m_n_host_ref
);
a_m_k
,
b_n_k
,
c_m_n_host_ref
);
const
float
max_accumulated_value
=
pass
=
ck_tile
::
check_err
(
c_m_n_dev_result
,
c_m_n_host_ref
);
*
std
::
max_element
(
c_m_n_host_ref
.
mData
.
begin
(),
c_m_n_host_ref
.
mData
.
end
());
const
auto
rtol_atol
=
calculate_rtol_atol
(
K
,
kbatch
,
max_accumulated_value
);
pass
=
ck_tile
::
check_err
(
c_m_n_dev_result
,
c_m_n_host_ref
,
"Error: Incorrect results!"
,
rtol_atol
.
at
(
ck_tile
::
number
<
0
>
{}),
rtol_atol
.
at
(
ck_tile
::
number
<
1
>
{}));
std
::
cout
<<
"Relative error threshold: "
<<
rtol_atol
.
at
(
ck_tile
::
number
<
0
>
{})
<<
" Absolute error threshold: "
<<
rtol_atol
.
at
(
ck_tile
::
number
<
1
>
{})
<<
std
::
endl
;
std
::
cout
<<
"The CPU veification result is:"
<<
(
pass
?
"correct"
:
"fail"
)
<<
std
::
endl
;
std
::
cout
<<
"The CPU veification result is:"
<<
(
pass
?
"correct"
:
"fail"
)
<<
std
::
endl
;
}
}
...
@@ -240,7 +270,18 @@ int run_batched_gemm_example_with_layouts(int argc,
...
@@ -240,7 +270,18 @@ int run_batched_gemm_example_with_layouts(int argc,
ck_tile
::
hip_check_error
(
hipFree
(
d_C
));
ck_tile
::
hip_check_error
(
hipFree
(
d_C
));
c_m_n_gpu_buf_ref
.
FromDevice
(
c_m_n_gpu_ref
.
data
());
c_m_n_gpu_buf_ref
.
FromDevice
(
c_m_n_gpu_ref
.
data
());
pass
=
ck_tile
::
check_err
(
c_m_n_dev_result
,
c_m_n_gpu_ref
);
const
float
max_accumulated_value
=
*
std
::
max_element
(
c_m_n_gpu_ref
.
mData
.
begin
(),
c_m_n_gpu_ref
.
mData
.
end
());
const
auto
rtol_atol
=
calculate_rtol_atol
(
K
,
kbatch
,
max_accumulated_value
);
pass
=
ck_tile
::
check_err
(
c_m_n_dev_result
,
c_m_n_gpu_ref
,
"Error: Incorrect results!"
,
rtol_atol
.
at
(
ck_tile
::
number
<
0
>
{}),
rtol_atol
.
at
(
ck_tile
::
number
<
1
>
{}));
std
::
cout
<<
"Relative error threshold: "
<<
rtol_atol
.
at
(
ck_tile
::
number
<
0
>
{})
<<
" Absolute error threshold: "
<<
rtol_atol
.
at
(
ck_tile
::
number
<
1
>
{})
<<
std
::
endl
;
std
::
cout
<<
"The GPU verification result is: "
<<
(
pass
?
"correct"
:
"fail"
)
<<
std
::
endl
;
std
::
cout
<<
"The GPU verification result is: "
<<
(
pass
?
"correct"
:
"fail"
)
<<
std
::
endl
;
}
}
...
...
example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc
View file @
bdddf1ea
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024
-2025
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
auto
calculate_rtol_atol
(
const
ck_tile
::
index_t
K
,
const
ck_tile
::
index_t
kbatch
,
const
float
max_accumulated_value
)
{
using
ComputeType
=
std
::
conditional_t
<
sizeof
(
ADataType
)
<
sizeof
(
BDataType
),
ADataType
,
BDataType
>
;
// Calculate thresholds
const
auto
rtol
=
ck_tile
::
get_relative_threshold
<
ComputeType
,
CDataType
,
AccDataType
>
(
ck_tile
::
integer_divide_ceil
(
K
,
kbatch
));
const
auto
atol
=
ck_tile
::
get_absolute_threshold
<
ComputeType
,
CDataType
,
AccDataType
>
(
max_accumulated_value
/
kbatch
,
ck_tile
::
integer_divide_ceil
(
K
,
kbatch
));
// Calculate error due to split_k accumulation
const
auto
rtol_split_k
=
ck_tile
::
get_relative_threshold
<
CDataType
,
CDataType
,
CDataType
>
(
kbatch
);
const
auto
atol_split_k
=
ck_tile
::
get_absolute_threshold
<
CDataType
,
CDataType
,
CDataType
>
(
max_accumulated_value
,
kbatch
);
// Use higher threshold
return
ck_tile
::
make_tuple
(
std
::
max
(
rtol
,
rtol_split_k
),
std
::
max
(
atol
,
atol_split_k
));
}
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
float
invoke_gemm
(
int
n_warmup
,
float
invoke_gemm
(
int
n_warmup
,
int
n_repeat
,
int
n_repeat
,
...
@@ -162,7 +182,18 @@ int run_grouped_gemm_example_with_layouts(int argc,
...
@@ -162,7 +182,18 @@ int run_grouped_gemm_example_with_layouts(int argc,
c_m_n_host_ref
.
SetZero
();
c_m_n_host_ref
.
SetZero
();
ck_tile
::
reference_gemm
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
>
(
ck_tile
::
reference_gemm
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
>
(
a_m_k_tensors
[
i
],
b_k_n_tensors
[
i
],
c_m_n_host_ref
);
a_m_k_tensors
[
i
],
b_k_n_tensors
[
i
],
c_m_n_host_ref
);
pass
&=
ck_tile
::
check_err
(
c_m_n_tensors
[
i
],
c_m_n_host_ref
);
const
float
max_accumulated_value
=
*
std
::
max_element
(
c_m_n_host_ref
.
mData
.
begin
(),
c_m_n_host_ref
.
mData
.
end
());
const
auto
rtol_atol
=
calculate_rtol_atol
(
Ks
[
i
],
1
/*kbatch*/
,
max_accumulated_value
);
pass
&=
ck_tile
::
check_err
(
c_m_n_tensors
[
i
],
c_m_n_host_ref
,
"Error: Incorrect results!"
,
rtol_atol
.
at
(
ck_tile
::
number
<
0
>
{}),
rtol_atol
.
at
(
ck_tile
::
number
<
1
>
{}));
std
::
cout
<<
"gemm["
<<
i
<<
"] Relative error threshold: "
<<
rtol_atol
.
at
(
ck_tile
::
number
<
0
>
{})
<<
" Absolute error threshold: "
<<
rtol_atol
.
at
(
ck_tile
::
number
<
1
>
{})
<<
std
::
endl
;
}
}
std
::
cout
<<
"The CPU veification result is:"
<<
(
pass
?
"correct"
:
"fail"
)
<<
std
::
endl
;
std
::
cout
<<
"The CPU veification result is:"
<<
(
pass
?
"correct"
:
"fail"
)
<<
std
::
endl
;
}
}
...
...
include/ck_tile/core/numeric/bfloat16.hpp
View file @
bdddf1ea
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
...
@@ -376,6 +376,16 @@ struct numeric<bfloat16_t>
...
@@ -376,6 +376,16 @@ struct numeric<bfloat16_t>
}
}
};
};
template
<
typename
T
>
struct
numeric_traits
;
template
<
>
struct
numeric_traits
<
bfloat16_t
>
{
static
constexpr
int
exp
=
8
;
static
constexpr
int
mant
=
7
;
};
#if CK_TILE_USE_CUSTOM_DATA_TYPE
#if CK_TILE_USE_CUSTOM_DATA_TYPE
CK_TILE_ARITHMETIC_USING_FLOAT
(
CK_TILE_HOST_DEVICE
,
bfloat16_t
)
CK_TILE_ARITHMETIC_USING_FLOAT
(
CK_TILE_HOST_DEVICE
,
bfloat16_t
)
#endif
#endif
...
...
include/ck_tile/host/check_err.hpp
View file @
bdddf1ea
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -18,6 +18,130 @@
...
@@ -18,6 +18,130 @@
namespace
ck_tile
{
namespace
ck_tile
{
template
<
typename
ComputeDataType
,
typename
OutDataType
,
typename
AccDataType
=
ComputeDataType
>
double
get_relative_threshold
(
const
int
number_of_accumulations
=
1
)
{
using
F8
=
ck_tile
::
fp8_t
;
using
F16
=
ck_tile
::
half_t
;
using
BF16
=
ck_tile
::
bf16_t
;
using
F32
=
float
;
using
I8
=
int8_t
;
using
I32
=
int32_t
;
static_assert
(
std
::
is_same_v
<
ComputeDataType
,
F8
>
||
std
::
is_same_v
<
ComputeDataType
,
F16
>
||
std
::
is_same_v
<
ComputeDataType
,
BF16
>
||
std
::
is_same_v
<
ComputeDataType
,
F32
>
||
std
::
is_same_v
<
ComputeDataType
,
I8
>
||
std
::
is_same_v
<
ComputeDataType
,
I32
>
||
std
::
is_same_v
<
ComputeDataType
,
int
>
,
"Warning: Unhandled ComputeDataType for setting up the relative threshold!"
);
double
compute_error
=
0
;
if
constexpr
(
std
::
is_same_v
<
ComputeDataType
,
I8
>
||
std
::
is_same_v
<
ComputeDataType
,
I32
>
||
std
::
is_same_v
<
ComputeDataType
,
int
>
)
{
return
0
;
}
else
{
compute_error
=
std
::
pow
(
2
,
-
numeric_traits
<
ComputeDataType
>::
mant
)
*
0.5
;
}
static_assert
(
std
::
is_same_v
<
OutDataType
,
F8
>
||
std
::
is_same_v
<
OutDataType
,
F16
>
||
std
::
is_same_v
<
OutDataType
,
BF16
>
||
std
::
is_same_v
<
OutDataType
,
F32
>
||
std
::
is_same_v
<
OutDataType
,
I8
>
||
std
::
is_same_v
<
OutDataType
,
I32
>
||
std
::
is_same_v
<
OutDataType
,
int
>
,
"Warning: Unhandled OutDataType for setting up the relative threshold!"
);
double
output_error
=
0
;
if
constexpr
(
std
::
is_same_v
<
OutDataType
,
I8
>
||
std
::
is_same_v
<
OutDataType
,
I32
>
||
std
::
is_same_v
<
OutDataType
,
int
>
)
{
return
0
;
}
else
{
output_error
=
std
::
pow
(
2
,
-
numeric_traits
<
OutDataType
>::
mant
)
*
0.5
;
}
double
midway_error
=
std
::
max
(
compute_error
,
output_error
);
static_assert
(
std
::
is_same_v
<
AccDataType
,
F8
>
||
std
::
is_same_v
<
AccDataType
,
F16
>
||
std
::
is_same_v
<
AccDataType
,
BF16
>
||
std
::
is_same_v
<
AccDataType
,
F32
>
||
std
::
is_same_v
<
AccDataType
,
I8
>
||
std
::
is_same_v
<
AccDataType
,
I32
>
||
std
::
is_same_v
<
AccDataType
,
int
>
,
"Warning: Unhandled AccDataType for setting up the relative threshold!"
);
double
acc_error
=
0
;
if
constexpr
(
std
::
is_same_v
<
AccDataType
,
I8
>
||
std
::
is_same_v
<
AccDataType
,
I32
>
||
std
::
is_same_v
<
AccDataType
,
int
>
)
{
return
0
;
}
else
{
acc_error
=
std
::
pow
(
2
,
-
numeric_traits
<
AccDataType
>::
mant
)
*
0.5
*
number_of_accumulations
;
}
return
std
::
max
(
acc_error
,
midway_error
);
}
template
<
typename
ComputeDataType
,
typename
OutDataType
,
typename
AccDataType
=
ComputeDataType
>
double
get_absolute_threshold
(
const
double
max_possible_num
,
const
int
number_of_accumulations
=
1
)
{
using
F8
=
ck_tile
::
fp8_t
;
using
F16
=
ck_tile
::
half_t
;
using
BF16
=
ck_tile
::
bf16_t
;
using
F32
=
float
;
using
I8
=
int8_t
;
using
I32
=
int32_t
;
static_assert
(
std
::
is_same_v
<
ComputeDataType
,
F8
>
||
std
::
is_same_v
<
ComputeDataType
,
F16
>
||
std
::
is_same_v
<
ComputeDataType
,
BF16
>
||
std
::
is_same_v
<
ComputeDataType
,
F32
>
||
std
::
is_same_v
<
ComputeDataType
,
I8
>
||
std
::
is_same_v
<
ComputeDataType
,
I32
>
||
std
::
is_same_v
<
ComputeDataType
,
int
>
,
"Warning: Unhandled ComputeDataType for setting up the absolute threshold!"
);
auto
expo
=
std
::
log2
(
std
::
abs
(
max_possible_num
));
double
compute_error
=
0
;
if
constexpr
(
std
::
is_same_v
<
ComputeDataType
,
I8
>
||
std
::
is_same_v
<
ComputeDataType
,
I32
>
||
std
::
is_same_v
<
ComputeDataType
,
int
>
)
{
return
0
;
}
else
{
compute_error
=
std
::
pow
(
2
,
expo
-
numeric_traits
<
ComputeDataType
>::
mant
)
*
0.5
;
}
static_assert
(
std
::
is_same_v
<
OutDataType
,
F8
>
||
std
::
is_same_v
<
OutDataType
,
F16
>
||
std
::
is_same_v
<
OutDataType
,
BF16
>
||
std
::
is_same_v
<
OutDataType
,
F32
>
||
std
::
is_same_v
<
OutDataType
,
I8
>
||
std
::
is_same_v
<
OutDataType
,
I32
>
||
std
::
is_same_v
<
OutDataType
,
int
>
,
"Warning: Unhandled OutDataType for setting up the absolute threshold!"
);
double
output_error
=
0
;
if
constexpr
(
std
::
is_same_v
<
OutDataType
,
I8
>
||
std
::
is_same_v
<
OutDataType
,
I32
>
||
std
::
is_same_v
<
OutDataType
,
int
>
)
{
return
0
;
}
else
{
output_error
=
std
::
pow
(
2
,
expo
-
numeric_traits
<
OutDataType
>::
mant
)
*
0.5
;
}
double
midway_error
=
std
::
max
(
compute_error
,
output_error
);
static_assert
(
std
::
is_same_v
<
AccDataType
,
F8
>
||
std
::
is_same_v
<
AccDataType
,
F16
>
||
std
::
is_same_v
<
AccDataType
,
BF16
>
||
std
::
is_same_v
<
AccDataType
,
F32
>
||
std
::
is_same_v
<
AccDataType
,
I8
>
||
std
::
is_same_v
<
AccDataType
,
I32
>
||
std
::
is_same_v
<
AccDataType
,
int
>
,
"Warning: Unhandled AccDataType for setting up the absolute threshold!"
);
double
acc_error
=
0
;
if
constexpr
(
std
::
is_same_v
<
AccDataType
,
I8
>
||
std
::
is_same_v
<
AccDataType
,
I32
>
||
std
::
is_same_v
<
AccDataType
,
int
>
)
{
return
0
;
}
else
{
acc_error
=
std
::
pow
(
2
,
expo
-
numeric_traits
<
AccDataType
>::
mant
)
*
0.5
*
number_of_accumulations
;
}
return
std
::
max
(
acc_error
,
midway_error
);
}
template
<
typename
T
>
template
<
typename
T
>
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
std
::
vector
<
T
>&
v
)
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
std
::
vector
<
T
>&
v
)
{
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment