Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
yangql
composable_kernel-1
Commits
89e1ebd4
Commit
89e1ebd4
authored
Nov 16, 2021
by
Jing Zhang
Browse files
updated bfloat16_to_float
parent
3737bb03
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
93 additions
and
165 deletions
+93
-165
composable_kernel/include/utility/config.hpp
composable_kernel/include/utility/config.hpp
+0
-1
composable_kernel/include/utility/data_type.hpp
composable_kernel/include/utility/data_type.hpp
+54
-2
composable_kernel/include/utility/inner_product.hpp
composable_kernel/include/utility/inner_product.hpp
+6
-0
external/rocm/include/bfloat16_dev.hpp
external/rocm/include/bfloat16_dev.hpp
+0
-125
host/driver_offline/src/conv_fwd_driver_offline.cpp
host/driver_offline/src/conv_fwd_driver_offline.cpp
+6
-6
host/host_tensor/include/host_gemm.hpp
host/host_tensor/include/host_gemm.hpp
+16
-16
host/host_tensor/include/host_tensor.hpp
host/host_tensor/include/host_tensor.hpp
+7
-12
host/host_tensor/include/host_tensor_generator.hpp
host/host_tensor/include/host_tensor_generator.hpp
+4
-3
No files found.
composable_kernel/include/utility/config.hpp
View file @
89e1ebd4
...
@@ -5,7 +5,6 @@
...
@@ -5,7 +5,6 @@
#include "hip/hip_runtime.h"
#include "hip/hip_runtime.h"
#include "hip/hip_fp16.h"
#include "hip/hip_fp16.h"
#endif
#endif
#include "bfloat16_dev.hpp"
// "Constant" address space for kernel parameter
// "Constant" address space for kernel parameter
#define CONSTANT __attribute__((address_space(4)))
#define CONSTANT __attribute__((address_space(4)))
...
...
composable_kernel/include/utility/data_type.hpp
View file @
89e1ebd4
...
@@ -927,6 +927,58 @@ using int8x16_t = typename vector_type<int8_t, 16>::type;
...
@@ -927,6 +927,58 @@ using int8x16_t = typename vector_type<int8_t, 16>::type;
using
int8x32_t
=
typename
vector_type
<
int8_t
,
32
>::
type
;
using
int8x32_t
=
typename
vector_type
<
int8_t
,
32
>::
type
;
using
int8x64_t
=
typename
vector_type
<
int8_t
,
64
>::
type
;
using
int8x64_t
=
typename
vector_type
<
int8_t
,
64
>::
type
;
/// Decode a bfloat16 bit pattern (carried in a ushort) into the float it
/// represents. bfloat16 is exactly the upper 16 bits of an IEEE-754
/// binary32, so widening the pattern and shifting it into the high half
/// reproduces the value with no rounding.
__host__ __device__ float bf16_to_f32(ushort src_val)
{
    // Place the bf16 bits in the high half of a 32-bit word, low half zero.
    const uint32_t hi_bits = uint32_t(src_val) << 16;
    union
    {
        uint32_t int32;
        float fp32;
    } pun = {hi_bits};
    return pun.fp32;
}
/// Encode a float as a bfloat16 bit pattern (returned in a ushort).
/// Finite values are rounded to nearest, ties to even; NaN payloads are
/// kept signaling-safe. Matches the rocBLAS bfloat16 conversion.
__host__ __device__ ushort f32_to_bf16(float src_val)
{
    union
    {
        float fp32;
        uint32_t int32;
    } u = {src_val};

    if((u.int32 & 0x7f800000) == 0x7f800000)
    {
        // Exponent bits all 1: the value is Inf (zero mantissa) or NaN
        // (any nonzero mantissa bit; MSB set = quiet, otherwise signaling).
        // If any of the 16 bits about to be truncated are 1, fold a 1 into
        // the least significant kept mantissa bit so a signaling NaN whose
        // payload lives entirely in the low bits does not collapse to Inf.
        if(u.int32 & 0xffff)
        {
            u.int32 |= 0x10000; // preserve (signaling) NaN
        }
    }
    else
    {
        // Zero, subnormal, or normal: round to nearest, ties to even.
        // Adding 0x7fff plus the LSB of the kept mantissa bumps the bf16
        // mantissa when the dropped 16 bits exceed 0x8000, or equal 0x8000
        // with an odd kept mantissa (tie -> even). A kept mantissa of 0x7f
        // carries into the exponent, which is the correct next-higher
        // value; a subnormal 0x00/0x7f rounds up to normal 0x01/0x00, and
        // exponent 0xFE with mantissa 0x7f rounds up to Inf, the next
        // higher value above the unrounded one.
        u.int32 += 0x7fff + ((u.int32 >> 16) & 1);
    }
    return uint16_t(u.int32 >> 16);
}
// data type conversion
// data type conversion
template
<
typename
T
>
template
<
typename
T
>
struct
type_convert
struct
type_convert
...
@@ -942,14 +994,14 @@ template <>
...
@@ -942,14 +994,14 @@ template <>
template
<
>
template
<
>
__device__
float
type_convert
<
float
>::
operator
()
<
ushort
>
(
ushort
x
)
const
__device__
float
type_convert
<
float
>::
operator
()
<
ushort
>
(
ushort
x
)
const
{
{
return
bf
loat
16_to_f
loat
(
x
);
return
bf16_to_f
32
(
x
);
}
}
template
<
>
template
<
>
template
<
>
template
<
>
__device__
ushort
type_convert
<
ushort
>::
operator
()
<
float
>
(
float
x
)
const
__device__
ushort
type_convert
<
ushort
>::
operator
()
<
float
>
(
float
x
)
const
{
{
return
f
loat
_to_bf
loat
16
(
x
);
return
f
32
_to_bf16
(
x
);
}
}
// TODO: deprecate this
// TODO: deprecate this
...
...
composable_kernel/include/utility/inner_product.hpp
View file @
89e1ebd4
...
@@ -28,6 +28,12 @@ __device__ void inner_product<float, float, float>(const float& a, const float&
...
@@ -28,6 +28,12 @@ __device__ void inner_product<float, float, float>(const float& a, const float&
#endif
#endif
}
}
// Scalar bf16 inner-product step: widen both bf16 operands (carried as
// ushort bit patterns) to float via bf16_to_f32, multiply, and accumulate
// into the float accumulator c. Specializes the inner_product primitive
// declared earlier in this header.
template <>
__device__ void
inner_product<ushort, ushort, float>(const ushort& a, const ushort& b, float& c)
{
    c += bf16_to_f32(a) * bf16_to_f32(b);
}
template
<
>
template
<
>
__device__
void
__device__
void
inner_product
<
float2_t
,
float2_t
,
float
>
(
const
float2_t
&
a
,
const
float2_t
&
b
,
float
&
c
)
inner_product
<
float2_t
,
float2_t
,
float
>
(
const
float2_t
&
a
,
const
float2_t
&
b
,
float
&
c
)
...
...
external/rocm/include/bfloat16_dev.hpp
deleted
100644 → 0
View file @
3737bb03
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2019 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef BFLOAT16_DEVICE_HPP
#define BFLOAT16_DEVICE_HPP
#ifdef __cplusplus
extern
"C"
{
#endif
#ifdef __HIP_PLATFORM_HCC__
#define EXECUTION_SPECIFIER __device__ __host__
#else
#define EXECUTION_SPECIFIER
#endif // MIOPEN_BACKEND_HIP
// Punning union used by the bf16<->fp32 converters below: the same 32 bits
// can be viewed as a raw u32, as a pair of 16-bit halves, or as a float.
typedef union
{
    uint u32;
    ushort2 ushortx2;
// Composable kernels are written in HIP language. The language doesn't support
// ushort2.hi or ushort2.low.
#ifdef __HIP_PLATFORM_HCC__
    // On HIP, address the two 16-bit halves as an array instead
    // ([0] = low half, [1] = high half — presumably little-endian layout).
    ushort ushortvec[2];
#endif // MIOPEN_BACKEND_HIP
    float f32;
} cvt_bf16_fp32_t;
// Decode a bfloat16 bit pattern (carried in a ushort) into the float it
// represents. bfloat16 is the upper 16 bits of a binary32, so placing
// src_val in the high half with the low half zeroed reproduces the value
// exactly, with no rounding.
EXECUTION_SPECIFIER float bfloat16_to_float(ushort src_val)
{
    cvt_bf16_fp32_t target_val;
#ifdef __HIP_PLATFORM_HCC__
    // make_ushort2(x, y): x is the low component, y the high component.
    target_val.ushortx2 = make_ushort2(0, src_val);
#else
    // OpenCL-style vector literal: (low, high).
    target_val.ushortx2 = (ushort2)(0, src_val);
#endif
    return target_val.f32;
}
// Encode a float as a bfloat16 bit pattern (returned in a ushort).
// With MIOPEN_USE_RNE_BFLOAT16 defined, finite values use round-to-nearest-
// even; otherwise finite values are truncated. NaNs are kept signaling-safe
// either way.
EXECUTION_SPECIFIER ushort float_to_bfloat16(float src_val)
{
    cvt_bf16_fp32_t target_val;
    target_val.f32 = src_val;
    // BF16 round and NaN preservation code matches
    // https://github.com/ROCmSoftwarePlatform/rocBLAS/blob/develop/library/include/rocblas_bfloat16.h
    if((~target_val.u32 & 0x7f800000) == 0) // Inf or NaN
    {
        // When all of the exponent bits are 1, the value is Inf or NaN.
        // Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
        // mantissa bit. Quiet NaN is indicated by the most significant mantissa
        // bit being 1. Signaling NaN is indicated by the most significant
        // mantissa bit being 0 but some other bit(s) being 1. If any of the
        // lower 16 bits of the mantissa are 1, we set the least significant bit
        // of the bfloat16 mantissa, in order to preserve signaling NaN in case
        // the bfloat16's mantissa bits are all 0.
        if((target_val.u32 & 0xffff) != 0)
        {
            target_val.u32 |= 0x10000; // Preserve signaling NaN
        }
    }
    else
    {
#ifdef MIOPEN_USE_RNE_BFLOAT16
        // When the exponent bits are not all 1s, then the value is zero, normal,
        // or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
        // 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
        // This causes the bfloat16's mantissa to be incremented by 1 if the 16
        // least significant bits of the float mantissa are greater than 0x8000,
        // or if they are equal to 0x8000 and the least significant bit of the
        // bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
        // the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
        // has the value 0x7f, then incrementing it causes it to become 0x00 and
        // the exponent is incremented by one, which is the next higher FP value
        // to the unrounded bfloat16 value. When the bfloat16 value is subnormal
        // with an exponent of 0x00 and a mantissa of 0x7F, it may be rounded up
        // to a normal value with an exponent of 0x01 and a mantissa of 0x00.
        // When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
        // incrementing it causes it to become an exponent of 0xFF and a mantissa
        // of 0x00, which is Inf, the next higher value to the unrounded value.
#ifdef __HIP_PLATFORM_HCC__
        // ushortvec[1] is the high half: its LSB is the bf16 mantissa LSB.
        target_val.u32 += (0x7fff + (target_val.ushortvec[1] & 1));
#else
        target_val.u32 += (0x7fff + (target_val.ushortx2.hi & 1)); // Round to nearest, round to even
#endif // MIOPEN_BACKEND_HIP
#endif // MIOPEN_USE_RNE_BFLOAT16
    }
    // The bf16 result is the high 16 bits of the (possibly rounded) word.
#ifdef __HIP_PLATFORM_HCC__
    return target_val.ushortvec[1];
#else
    return target_val.ushortx2.hi;
#endif // MIOPEN_BACKEND_HIP
}
#ifdef __cplusplus
}
#endif
#endif // BFLOAT16_DEVICE_HPP
host/driver_offline/src/conv_fwd_driver_offline.cpp
View file @
89e1ebd4
...
@@ -82,8 +82,8 @@ void host_convolution_forward(const Tensor<TIn>& in,
...
@@ -82,8 +82,8 @@ void host_convolution_forward(const Tensor<TIn>& in,
{
{
if
constexpr
(
is_same
<
TIn
,
ushort
>::
value
)
if
constexpr
(
is_same
<
TIn
,
ushort
>::
value
)
{
{
v
+=
bfloat
16_to_f
loat
(
in
(
n
,
c
,
hi
,
wi
))
*
v
+=
ck
::
bf
16_to_f
32
(
in
(
n
,
c
,
hi
,
wi
))
*
bfloat
16_to_f
loat
(
wei
(
k
,
c
,
y
,
x
));
ck
::
bf
16_to_f
32
(
wei
(
k
,
c
,
y
,
x
));
}
}
else
else
{
{
...
@@ -97,7 +97,7 @@ void host_convolution_forward(const Tensor<TIn>& in,
...
@@ -97,7 +97,7 @@ void host_convolution_forward(const Tensor<TIn>& in,
if
constexpr
(
is_same
<
TOut
,
ushort
>::
value
)
if
constexpr
(
is_same
<
TOut
,
ushort
>::
value
)
{
{
out
(
n
,
k
,
ho
,
wo
)
=
f
loat
_to_bf
loat
16
(
v
);
out
(
n
,
k
,
ho
,
wo
)
=
f
32
_to_bf16
(
v
);
}
}
else
else
{
{
...
@@ -120,8 +120,8 @@ void host_convolution_forward(const Tensor<TIn>& in,
...
@@ -120,8 +120,8 @@ void host_convolution_forward(const Tensor<TIn>& in,
{
{
if
constexpr
(
is_same
<
TIn
,
ushort
>::
value
)
if
constexpr
(
is_same
<
TIn
,
ushort
>::
value
)
{
{
v
+=
bfloat
16_to_f
loat
(
in
(
n
,
hi
,
wi
,
c
))
*
v
+=
ck
::
bf
16_to_f
32
(
in
(
n
,
hi
,
wi
,
c
))
*
bfloat
16_to_f
loat
(
wei
(
k
,
y
,
x
,
c
));
ck
::
bf
16_to_f
32
(
wei
(
k
,
y
,
x
,
c
));
}
}
else
else
{
{
...
@@ -134,7 +134,7 @@ void host_convolution_forward(const Tensor<TIn>& in,
...
@@ -134,7 +134,7 @@ void host_convolution_forward(const Tensor<TIn>& in,
}
}
if
constexpr
(
is_same
<
TOut
,
ushort
>::
value
)
if
constexpr
(
is_same
<
TOut
,
ushort
>::
value
)
{
{
out
(
n
,
ho
,
wo
,
k
)
=
f
loat
_to_bf
loat
16
(
v
);
out
(
n
,
ho
,
wo
,
k
)
=
f
32
_to_bf16
(
v
);
}
}
else
else
{
{
...
...
host/host_tensor/include/host_gemm.hpp
View file @
89e1ebd4
...
@@ -16,10 +16,10 @@ void host_gemm<ushort, ushort, ushort>(const Tensor<ushort>& a,
...
@@ -16,10 +16,10 @@ void host_gemm<ushort, ushort, ushort>(const Tensor<ushort>& a,
for
(
int
k
=
0
;
k
<
K
;
++
k
)
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
{
v
+=
bfloat
16_to_f
loat
(
a
(
m
,
k
))
*
bfloat
16_to_f
loat
(
b
(
k
,
n
));
v
+=
ck
::
bf
16_to_f
32
(
a
(
m
,
k
))
*
ck
::
bf
16_to_f
32
(
b
(
k
,
n
));
}
}
c
(
m
,
n
)
=
float
_to_bf
loat
16
(
v
);
c
(
m
,
n
)
=
ck
::
f32
_to_bf16
(
v
);
};
};
make_ParallelTensorFunctor
(
f_mk_kn_mn
,
c
.
mDesc
.
GetLengths
()[
0
],
c
.
mDesc
.
GetLengths
()[
1
])(
make_ParallelTensorFunctor
(
f_mk_kn_mn
,
c
.
mDesc
.
GetLengths
()[
0
],
c
.
mDesc
.
GetLengths
()[
1
])(
...
@@ -34,10 +34,10 @@ void host_gemm<ushort, ushort, ushort>(const Tensor<ushort>& a,
...
@@ -34,10 +34,10 @@ void host_gemm<ushort, ushort, ushort>(const Tensor<ushort>& a,
for
(
int
k
=
0
;
k
<
K
;
++
k
)
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
{
v
+=
bfloat
16_to_f
loat
(
a
(
m
,
k
))
*
bfloat
16_to_f
loat
(
b
(
n
,
k
));
v
+=
ck
::
bf
16_to_f
32
(
a
(
m
,
k
))
*
ck
::
bf
16_to_f
32
(
b
(
n
,
k
));
}
}
c
(
m
,
n
)
=
float
_to_bf
loat
16
(
v
);
c
(
m
,
n
)
=
ck
::
f32
_to_bf16
(
v
);
};
};
make_ParallelTensorFunctor
(
f_mk_nk_mn
,
c
.
mDesc
.
GetLengths
()[
0
],
c
.
mDesc
.
GetLengths
()[
1
])(
make_ParallelTensorFunctor
(
f_mk_nk_mn
,
c
.
mDesc
.
GetLengths
()[
0
],
c
.
mDesc
.
GetLengths
()[
1
])(
...
@@ -52,10 +52,10 @@ void host_gemm<ushort, ushort, ushort>(const Tensor<ushort>& a,
...
@@ -52,10 +52,10 @@ void host_gemm<ushort, ushort, ushort>(const Tensor<ushort>& a,
for
(
int
k
=
0
;
k
<
K
;
++
k
)
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
{
v
+=
bfloat
16_to_f
loat
(
a
(
k
,
m
))
*
bfloat
16_to_f
loat
(
b
(
k
,
n
));
v
+=
ck
::
bf
16_to_f
32
(
a
(
k
,
m
))
*
ck
::
bf
16_to_f
32
(
b
(
k
,
n
));
}
}
c
(
m
,
n
)
=
float
_to_bf
loat
16
(
v
);
c
(
m
,
n
)
=
ck
::
f32
_to_bf16
(
v
);
};
};
make_ParallelTensorFunctor
(
f_km_kn_mn
,
c
.
mDesc
.
GetLengths
()[
0
],
c
.
mDesc
.
GetLengths
()[
1
])(
make_ParallelTensorFunctor
(
f_km_kn_mn
,
c
.
mDesc
.
GetLengths
()[
0
],
c
.
mDesc
.
GetLengths
()[
1
])(
...
@@ -70,10 +70,10 @@ void host_gemm<ushort, ushort, ushort>(const Tensor<ushort>& a,
...
@@ -70,10 +70,10 @@ void host_gemm<ushort, ushort, ushort>(const Tensor<ushort>& a,
for
(
int
k
=
0
;
k
<
K
;
++
k
)
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
{
v
+=
bfloat
16_to_f
loat
(
a
(
k
,
m
))
*
bfloat
16_to_f
loat
(
b
(
n
,
k
));
v
+=
ck
::
bf
16_to_f
32
(
a
(
k
,
m
))
*
ck
::
bf
16_to_f
32
(
b
(
n
,
k
));
}
}
c
(
m
,
n
)
=
float
_to_bf
loat
16
(
v
);
c
(
m
,
n
)
=
ck
::
f32
_to_bf16
(
v
);
};
};
make_ParallelTensorFunctor
(
f_km_nk_mn
,
c
.
mDesc
.
GetLengths
()[
0
],
c
.
mDesc
.
GetLengths
()[
1
])(
make_ParallelTensorFunctor
(
f_km_nk_mn
,
c
.
mDesc
.
GetLengths
()[
0
],
c
.
mDesc
.
GetLengths
()[
1
])(
...
@@ -88,10 +88,10 @@ void host_gemm<ushort, ushort, ushort>(const Tensor<ushort>& a,
...
@@ -88,10 +88,10 @@ void host_gemm<ushort, ushort, ushort>(const Tensor<ushort>& a,
for
(
int
k
=
0
;
k
<
K
;
++
k
)
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
{
v
+=
bfloat
16_to_f
loat
(
a
(
m
,
k
))
*
bfloat
16_to_f
loat
(
b
(
k
,
n
));
v
+=
ck
::
bf
16_to_f
32
(
a
(
m
,
k
))
*
ck
::
bf
16_to_f
32
(
b
(
k
,
n
));
}
}
c
(
n
,
m
)
=
float
_to_bf
loat
16
(
v
);
c
(
n
,
m
)
=
ck
::
f32
_to_bf16
(
v
);
};
};
make_ParallelTensorFunctor
(
f_mk_kn_nm
,
c
.
mDesc
.
GetLengths
()[
0
],
c
.
mDesc
.
GetLengths
()[
1
])(
make_ParallelTensorFunctor
(
f_mk_kn_nm
,
c
.
mDesc
.
GetLengths
()[
0
],
c
.
mDesc
.
GetLengths
()[
1
])(
...
@@ -106,10 +106,10 @@ void host_gemm<ushort, ushort, ushort>(const Tensor<ushort>& a,
...
@@ -106,10 +106,10 @@ void host_gemm<ushort, ushort, ushort>(const Tensor<ushort>& a,
for
(
int
k
=
0
;
k
<
K
;
++
k
)
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
{
v
+=
bfloat
16_to_f
loat
(
a
(
m
,
k
))
*
bfloat
16_to_f
loat
(
b
(
n
,
k
));
v
+=
ck
::
bf
16_to_f
32
(
a
(
m
,
k
))
*
ck
::
bf
16_to_f
32
(
b
(
n
,
k
));
}
}
c
(
n
,
m
)
=
float
_to_bf
loat
16
(
v
);
c
(
n
,
m
)
=
ck
::
f32
_to_bf16
(
v
);
};
};
make_ParallelTensorFunctor
(
f_mk_nk_nm
,
c
.
mDesc
.
GetLengths
()[
0
],
c
.
mDesc
.
GetLengths
()[
1
])(
make_ParallelTensorFunctor
(
f_mk_nk_nm
,
c
.
mDesc
.
GetLengths
()[
0
],
c
.
mDesc
.
GetLengths
()[
1
])(
...
@@ -124,10 +124,10 @@ void host_gemm<ushort, ushort, ushort>(const Tensor<ushort>& a,
...
@@ -124,10 +124,10 @@ void host_gemm<ushort, ushort, ushort>(const Tensor<ushort>& a,
for
(
int
k
=
0
;
k
<
K
;
++
k
)
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
{
v
+=
bfloat
16_to_f
loat
(
a
(
k
,
m
))
*
bfloat
16_to_f
loat
(
b
(
k
,
n
));
v
+=
ck
::
bf
16_to_f
32
(
a
(
k
,
m
))
*
ck
::
bf
16_to_f
32
(
b
(
k
,
n
));
}
}
c
(
n
,
m
)
=
float
_to_bf
loat
16
(
v
);
c
(
n
,
m
)
=
ck
::
f32
_to_bf16
(
v
);
};
};
make_ParallelTensorFunctor
(
f_km_kn_nm
,
c
.
mDesc
.
GetLengths
()[
0
],
c
.
mDesc
.
GetLengths
()[
1
])(
make_ParallelTensorFunctor
(
f_km_kn_nm
,
c
.
mDesc
.
GetLengths
()[
0
],
c
.
mDesc
.
GetLengths
()[
1
])(
...
@@ -142,10 +142,10 @@ void host_gemm<ushort, ushort, ushort>(const Tensor<ushort>& a,
...
@@ -142,10 +142,10 @@ void host_gemm<ushort, ushort, ushort>(const Tensor<ushort>& a,
for
(
int
k
=
0
;
k
<
K
;
++
k
)
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
{
v
+=
bfloat
16_to_f
loat
(
a
(
k
,
m
))
*
bfloat
16_to_f
loat
(
b
(
n
,
k
));
v
+=
ck
::
bf
16_to_f
32
(
a
(
k
,
m
))
*
ck
::
bf
16_to_f
32
(
b
(
n
,
k
));
}
}
c
(
n
,
m
)
=
float
_to_bf
loat
16
(
v
);
c
(
n
,
m
)
=
ck
::
f32
_to_bf16
(
v
);
};
};
make_ParallelTensorFunctor
(
f_km_nk_nm
,
c
.
mDesc
.
GetLengths
()[
0
],
c
.
mDesc
.
GetLengths
()[
1
])(
make_ParallelTensorFunctor
(
f_km_nk_nm
,
c
.
mDesc
.
GetLengths
()[
0
],
c
.
mDesc
.
GetLengths
()[
1
])(
...
...
host/host_tensor/include/host_tensor.hpp
View file @
89e1ebd4
...
@@ -321,18 +321,14 @@ void check_error(const Tensor<T>& ref, const Tensor<T>& result)
...
@@ -321,18 +321,14 @@ void check_error(const Tensor<T>& ref, const Tensor<T>& result)
std
::
cout
<<
"max_diff: "
<<
max_diff
<<
", "
<<
ref_value
<<
", "
<<
result_value
<<
std
::
endl
;
std
::
cout
<<
"max_diff: "
<<
max_diff
<<
", "
<<
ref_value
<<
", "
<<
result_value
<<
std
::
endl
;
}
}
float
bf16_to_f32
(
ushort
src_val
)
__host__
__device__
float
bf16_to_f32
(
ushort
src_val
)
{
{
typedef
union
union
{
{
ushort
x
,
y
;
uint32_t
int32
;
float
f32
;
float
fp32
;
}
bf16_f32_t
;
}
u
=
{
uint32_t
(
src_val
)
<<
16
};
return
u
.
fp32
;
bf16_f32_t
v
;
v
.
x
=
0
;
v
.
y
=
src_val
;
return
v
.
f32
;
}
}
template
<
>
template
<
>
...
@@ -354,8 +350,7 @@ void check_error<ushort>(const Tensor<ushort>& ref, const Tensor<ushort>& result
...
@@ -354,8 +350,7 @@ void check_error<ushort>(const Tensor<ushort>& ref, const Tensor<ushort>& result
}
}
std
::
cout
<<
"error: "
<<
error
<<
std
::
endl
;
std
::
cout
<<
"error: "
<<
error
<<
std
::
endl
;
std
::
cout
<<
"max_diff: "
<<
max_diff
<<
", ref: "
<<
ref_value
<<
", res: "
<<
result_value
std
::
cout
<<
"max_diff: "
<<
max_diff
<<
", "
<<
ref_value
<<
", "
<<
result_value
<<
std
::
endl
;
<<
std
::
endl
;
}
}
#endif
#endif
host/host_tensor/include/host_tensor_generator.hpp
View file @
89e1ebd4
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
#include <cmath>
#include <cmath>
#include "config.hpp"
#include "config.hpp"
#include "data_type.hpp"
template
<
typename
T
>
template
<
typename
T
>
struct
GeneratorTensor_1
struct
GeneratorTensor_1
...
@@ -24,7 +25,7 @@ struct GeneratorTensor_1<ushort>
...
@@ -24,7 +25,7 @@ struct GeneratorTensor_1<ushort>
template
<
typename
...
Is
>
template
<
typename
...
Is
>
ushort
operator
()(
Is
...)
ushort
operator
()(
Is
...)
{
{
return
float
_to_bf
loat
16
(
value
);
return
ck
::
f32
_to_bf16
(
value
);
}
}
};
};
...
@@ -74,7 +75,7 @@ struct GeneratorTensor_2<ushort>
...
@@ -74,7 +75,7 @@ struct GeneratorTensor_2<ushort>
ushort
operator
()(
Is
...)
ushort
operator
()(
Is
...)
{
{
float
tmp
=
(
std
::
rand
()
%
(
max_value
-
min_value
))
+
min_value
;
float
tmp
=
(
std
::
rand
()
%
(
max_value
-
min_value
))
+
min_value
;
return
float
_to_bf
loat
16
(
tmp
);
return
ck
::
f32
_to_bf16
(
tmp
);
}
}
};
};
...
@@ -119,7 +120,7 @@ struct GeneratorTensor_3<ushort>
...
@@ -119,7 +120,7 @@ struct GeneratorTensor_3<ushort>
float
fp32_tmp
=
min_value
+
tmp
*
(
max_value
-
min_value
);
float
fp32_tmp
=
min_value
+
tmp
*
(
max_value
-
min_value
);
return
float
_to_bf
loat
16
(
fp32_tmp
);
return
ck
::
f32
_to_bf16
(
fp32_tmp
);
}
}
};
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment