Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
ee768148
Unverified
Commit
ee768148
authored
Jul 18, 2024
by
Qianfeng
Committed by
GitHub
Jul 17, 2024
Browse files
Replace the using of __expf by __ocml_exp_f32 to work-around the test_softmax_rank4 failure (#1394)
parent
9cac2827
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
13 additions
and
10 deletions
+13
-10
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
...or_operation/gpu/element/unary_element_wise_operation.hpp
+3
-3
include/ck/utility/math_v2.hpp
include/ck/utility/math_v2.hpp
+2
-2
include/ck_tile/core/numeric/bfloat16.hpp
include/ck_tile/core/numeric/bfloat16.hpp
+4
-1
include/ck_tile/core/numeric/float8.hpp
include/ck_tile/core/numeric/float8.hpp
+2
-2
include/ck_tile/core/numeric/half.hpp
include/ck_tile/core/numeric/half.hpp
+1
-1
include/ck_tile/core/numeric/math.hpp
include/ck_tile/core/numeric/math.hpp
+1
-1
No files found.
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
View file @
ee768148
...
...
@@ -431,7 +431,7 @@ struct Relu
// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
// host code use higher accuracy "exp" and "div"
// gpu code use lower accuracy "_
_expf
" and "rcp" function
// gpu code use lower accuracy "_
ocml_exp_f32
" and "rcp" function
struct
FastGelu
{
template
<
typename
Y
,
typename
X
>
...
...
@@ -451,7 +451,7 @@ struct FastGelu
y
=
x
/
(
1.
f
+
emu
);
}
// device code, use lower precision "__
expf
" and "rcp"
// device code, use lower precision "__
ocml_exp_f32
" and "rcp"
template
<
>
__device__
void
operator
()
<
float
,
float
>
(
float
&
y
,
const
float
&
x
)
const
{
...
...
@@ -459,7 +459,7 @@ struct FastGelu
const
float
c1
=
-
2.0
*
0.035677
f
;
const
float
c2
=
-
2.0
*
0.797885
f
;
const
float
u
=
x
*
(
c1
*
x
*
x
+
c2
);
const
float
emu
=
__
expf
(
u
);
const
float
emu
=
__
ocml_exp_f32
(
u
);
y
=
x
*
ck
::
math
::
rcp
(
1.
f
+
emu
);
}
...
...
include/ck/utility/math_v2.hpp
View file @
ee768148
...
...
@@ -839,7 +839,7 @@ inline __device__ T rcp(T x)
template
<
typename
T
>
inline
__device__
T
exp
(
T
x
)
{
return
ck
::
type_convert
<
T
>
(
__
expf
(
ck
::
type_convert
<
float
>
(
x
)));
return
ck
::
type_convert
<
T
>
(
__
ocml_exp_f32
(
ck
::
type_convert
<
float
>
(
x
)));
};
template
<
>
...
...
@@ -851,7 +851,7 @@ inline __device__ half_t exp<half_t>(half_t x)
template
<
>
inline
__device__
float
exp
<
float
>
(
float
x
)
{
return
__
expf
(
x
);
return
__
ocml_exp_f32
(
x
);
};
template
<
>
...
...
include/ck_tile/core/numeric/bfloat16.hpp
View file @
ee768148
...
...
@@ -331,7 +331,10 @@ bfloat16_t sqrt(bfloat16_t x)
};
CK_TILE_DEVICE
bfloat16_t
exp
(
bfloat16_t
x
)
{
return
static_cast
<
bfloat16_t
>
(
__expf
(
static_cast
<
float
>
(
x
)));
};
bfloat16_t
exp
(
bfloat16_t
x
)
{
return
static_cast
<
bfloat16_t
>
(
__ocml_exp_f32
(
static_cast
<
float
>
(
x
)));
};
CK_TILE_DEVICE
bfloat16_t
exp2
(
bfloat16_t
x
)
{
return
static_cast
<
bfloat16_t
>
(
exp2f
(
static_cast
<
float
>
(
x
)));
};
...
...
include/ck_tile/core/numeric/float8.hpp
View file @
ee768148
...
...
@@ -835,7 +835,7 @@ CK_TILE_DEVICE
fp8_t
sqrt
(
fp8_t
x
)
{
return
static_cast
<
fp8_t
>
(
__builtin_amdgcn_sqrtf
(
static_cast
<
float
>
(
x
)));
};
CK_TILE_DEVICE
fp8_t
exp
(
fp8_t
x
)
{
return
static_cast
<
fp8_t
>
(
__
expf
(
static_cast
<
float
>
(
x
)));
};
fp8_t
exp
(
fp8_t
x
)
{
return
static_cast
<
fp8_t
>
(
__
ocml_exp_f32
(
static_cast
<
float
>
(
x
)));
};
CK_TILE_DEVICE
fp8_t
exp2
(
fp8_t
x
)
{
return
static_cast
<
fp8_t
>
(
exp2f
(
static_cast
<
float
>
(
x
)));
};
...
...
@@ -860,7 +860,7 @@ CK_TILE_DEVICE
bf8_t
sqrt
(
bf8_t
x
)
{
return
static_cast
<
bf8_t
>
(
__builtin_amdgcn_sqrtf
(
static_cast
<
float
>
(
x
)));
};
CK_TILE_DEVICE
bf8_t
exp
(
bf8_t
x
)
{
return
static_cast
<
bf8_t
>
(
__
expf
(
static_cast
<
float
>
(
x
)));
};
bf8_t
exp
(
bf8_t
x
)
{
return
static_cast
<
bf8_t
>
(
__
ocml_exp_f32
(
static_cast
<
float
>
(
x
)));
};
CK_TILE_DEVICE
bf8_t
exp2
(
bf8_t
x
)
{
return
static_cast
<
bf8_t
>
(
exp2f
(
static_cast
<
float
>
(
x
)));
};
...
...
include/ck_tile/core/numeric/half.hpp
View file @
ee768148
...
...
@@ -374,7 +374,7 @@ half_t sqrt(half_t x)
};
CK_TILE_DEVICE
half_t
exp
(
half_t
x
)
{
return
static_cast
<
half_t
>
(
__
expf
(
static_cast
<
float
>
(
x
)));
};
half_t
exp
(
half_t
x
)
{
return
static_cast
<
half_t
>
(
__
ocml_exp_f32
(
static_cast
<
float
>
(
x
)));
};
CK_TILE_DEVICE
half_t
exp2
(
half_t
x
)
{
return
static_cast
<
half_t
>
(
exp2f
(
static_cast
<
float
>
(
x
)));
};
...
...
include/ck_tile/core/numeric/math.hpp
View file @
ee768148
...
...
@@ -519,7 +519,7 @@ CK_TILE_DEVICE
double
sqrt
(
double
x
)
{
return
__builtin_amdgcn_sqrt
(
x
);
};
CK_TILE_DEVICE
float
exp
(
float
x
)
{
return
__
expf
(
x
);
};
float
exp
(
float
x
)
{
return
__
ocml_exp_f32
(
x
);
};
CK_TILE_HOST
float
exp
(
float
x
)
{
return
std
::
expf
(
x
);
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment