Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
99af398f
Commit
99af398f
authored
Oct 19, 2023
by
Rostyslav Geyyer
Browse files
Add bf8 functionality
parent
b2b68eff
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
2 additions
and
7 deletions
+2
-7
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
...or_operation/gpu/element/unary_element_wise_operation.hpp
+2
-7
No files found.
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
View file @
99af398f
...
@@ -113,7 +113,6 @@ struct PassThrough
...
@@ -113,7 +113,6 @@ struct PassThrough
}
}
#endif
#endif
#if defined CK_ENABLE_FP8
template
<
>
template
<
>
__host__
__device__
void
operator
()
<
f8_t
,
f8_t
>
(
f8_t
&
y
,
const
f8_t
&
x
)
const
__host__
__device__
void
operator
()
<
f8_t
,
f8_t
>
(
f8_t
&
y
,
const
f8_t
&
x
)
const
{
{
...
@@ -143,9 +142,7 @@ struct PassThrough
...
@@ -143,9 +142,7 @@ struct PassThrough
{
{
y
=
type_convert
<
f8_t
>
(
x
);
y
=
type_convert
<
f8_t
>
(
x
);
}
}
#endif
#if defined CK_ENABLE_BF8
template
<
>
template
<
>
__host__
__device__
void
operator
()
<
bf8_t
,
bf8_t
>
(
bf8_t
&
y
,
const
bf8_t
&
x
)
const
__host__
__device__
void
operator
()
<
bf8_t
,
bf8_t
>
(
bf8_t
&
y
,
const
bf8_t
&
x
)
const
{
{
...
@@ -175,7 +172,6 @@ struct PassThrough
...
@@ -175,7 +172,6 @@ struct PassThrough
{
{
y
=
ck
::
type_convert
<
bf8_t
>
(
x
);
y
=
ck
::
type_convert
<
bf8_t
>
(
x
);
}
}
#endif
};
};
struct
UnaryConvert
struct
UnaryConvert
...
@@ -204,7 +200,6 @@ struct ConvertBF16RTN
...
@@ -204,7 +200,6 @@ struct ConvertBF16RTN
}
}
};
};
#if defined CK_ENABLE_FP8
struct
ConvertF8SR
struct
ConvertF8SR
{
{
// convert to fp8 using stochastic rounding (SR)
// convert to fp8 using stochastic rounding (SR)
...
@@ -212,7 +207,8 @@ struct ConvertF8SR
...
@@ -212,7 +207,8 @@ struct ConvertF8SR
__host__
__device__
void
operator
()(
Y
&
y
,
const
X
&
x
)
const
__host__
__device__
void
operator
()(
Y
&
y
,
const
X
&
x
)
const
{
{
// check Y datatype
// check Y datatype
static_assert
(
is_same
<
Y
,
f8_t
>::
value
,
"Data type is not supported by this operation!"
);
static_assert
(
is_same
<
Y
,
f8_t
>::
value
||
is_same
<
Y
,
bf8_t
>::
value
,
"Data type is not supported by this operation!"
);
// check X datatype
// check X datatype
static_assert
(
is_same
<
X
,
float
>::
value
||
is_same
<
X
,
half_t
>::
value
,
static_assert
(
is_same
<
X
,
float
>::
value
||
is_same
<
X
,
half_t
>::
value
,
...
@@ -221,7 +217,6 @@ struct ConvertF8SR
...
@@ -221,7 +217,6 @@ struct ConvertF8SR
y
=
f8_convert_sr
<
Y
>
(
x
);
y
=
f8_convert_sr
<
Y
>
(
x
);
}
}
};
};
#endif
struct
Scale
struct
Scale
{
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment