Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
5e206700
Commit
5e206700
authored
Oct 11, 2024
by
Astha Rai
Browse files
temp fix for namespace error in MIOpen
parent
251ab612
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
74 additions
and
74 deletions
+74
-74
include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp
...ation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp
+1
-1
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
...or_operation/gpu/element/unary_element_wise_operation.hpp
+69
-69
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
+2
-2
include/ck/utility/data_type.hpp
include/ck/utility/data_type.hpp
+2
-2
No files found.
include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp
View file @
5e206700
...
...
@@ -163,8 +163,8 @@ struct DeviceGroupedConvFwdMultipleABD : public BaseOperator
const
CDEElementwiseOperation
&
cde_element_op
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
#endif
};
}
// namespace device
}
// namespace tensor_operation
...
...
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
View file @
5e206700
...
...
@@ -17,7 +17,7 @@ struct PassThroughPack2
template
<
typename
Y
,
typename
X
>
__host__
__device__
void
operator
()(
Y
&
y
,
const
X
&
x
)
const
;
__host__
__device__
constexpr
void
operator
()(
ck
::
half2_t
&
y
,
const
ck
::
f8x2_t
&
x
)
const
__host__
__device__
constexpr
void
operator
()(
half2_t
&
y
,
const
f8x2_t
&
x
)
const
{
auto
t
=
type_convert
<
float2_t
>
(
x
);
y
=
type_convert
<
half2_t
>
(
t
);
...
...
@@ -220,7 +220,7 @@ struct PassThrough
template
<
>
__host__
__device__
void
operator
()
<
bf8_t
,
half_t
>
(
bf8_t
&
y
,
const
half_t
&
x
)
const
{
y
=
ck
::
type_convert
<
bf8_t
>
(
x
);
y
=
type_convert
<
bf8_t
>
(
x
);
}
};
...
...
@@ -293,21 +293,21 @@ struct Scale
template
<
typename
Y
,
typename
X
>
__host__
__device__
void
operator
()(
Y
&
y
,
const
X
&
x
)
const
{
y
=
ck
::
type_convert
<
Y
>
(
ck
::
type_convert
<
float
>
(
x
)
*
scale_
);
y
=
type_convert
<
Y
>
(
type_convert
<
float
>
(
x
)
*
scale_
);
}
template
<
>
__host__
__device__
void
operator
()
<
half_t
,
half_t
>
(
half_t
&
y
,
const
half_t
&
x
)
const
{
y
=
ck
::
type_convert
<
half_t
>
(
scale_
)
*
x
;
y
=
type_convert
<
half_t
>
(
scale_
)
*
x
;
};
template
<
>
__host__
__device__
void
operator
()
<
bhalf_t
,
bhalf_t
>
(
bhalf_t
&
y
,
const
bhalf_t
&
x
)
const
{
const
float
x_tmp
=
ck
::
type_convert
<
float
>
(
x
);
const
float
x_tmp
=
type_convert
<
float
>
(
x
);
const
float
y_tmp
=
scale_
*
x_tmp
;
y
=
ck
::
type_convert
<
bhalf_t
>
(
y_tmp
);
y
=
type_convert
<
bhalf_t
>
(
y_tmp
);
};
template
<
>
...
...
@@ -325,7 +325,7 @@ struct Scale
template
<
>
__host__
__device__
void
operator
()
<
int8_t
,
int8_t
>
(
int8_t
&
y
,
const
int8_t
&
x
)
const
{
y
=
ck
::
type_convert
<
int8_t
>
(
scale_
*
ck
::
type_convert
<
float
>
(
x
));
y
=
type_convert
<
int8_t
>
(
scale_
*
type_convert
<
float
>
(
x
));
};
float
scale_
;
...
...
@@ -341,7 +341,7 @@ struct ScaleAndResetNaNToMinusInfinity
template
<
>
__host__
__device__
void
operator
()
<
float
,
float
>
(
float
&
y
,
const
float
&
x
)
const
{
y
=
ck
::
math
::
isnan
(
x
)
?
-
ck
::
NumericLimits
<
float
>::
Infinity
()
:
scale_
*
x
;
y
=
math
::
isnan
(
x
)
?
-
NumericLimits
<
float
>::
Infinity
()
:
scale_
*
x
;
};
float
scale_
;
...
...
@@ -417,7 +417,7 @@ struct UnaryAbs
is_same
<
T
,
int8_t
>::
value
,
"Data type is not supported by this operation!"
);
y
=
ck
::
math
::
abs
(
x
);
y
=
math
::
abs
(
x
);
};
};
...
...
@@ -429,7 +429,7 @@ struct UnarySqrt
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
,
"Data type is not supported by this operation!"
);
y
=
ck
::
math
::
sqrt
(
x
);
y
=
math
::
sqrt
(
x
);
};
};
...
...
@@ -448,9 +448,9 @@ struct Relu
template
<
>
__host__
__device__
void
operator
()(
bhalf_t
&
y
,
const
bhalf_t
&
x
)
const
{
float
x_f32
=
ck
::
type_convert
<
float
>
(
x
);
float
x_f32
=
type_convert
<
float
>
(
x
);
float
y_f32
=
x_f32
>
0
?
x_f32
:
0
;
y
=
ck
::
type_convert
<
bhalf_t
>
(
y_f32
);
y
=
type_convert
<
bhalf_t
>
(
y_f32
);
}
};
...
...
@@ -466,7 +466,7 @@ struct FastGelu
template
<
typename
Y
,
typename
X
>
__device__
void
operator
()(
Y
&
y
,
const
X
&
x
)
const
;
#ifndef CK_CODE_GEN_RTC
template
<
>
__host__
void
operator
()
<
float
,
float
>
(
float
&
y
,
const
float
&
x
)
const
{
...
...
@@ -477,6 +477,7 @@ struct FastGelu
const
float
emu
=
exp
(
u
);
y
=
x
/
(
1.
f
+
emu
);
}
#endif
// device code, use lower precision "__ocml_exp_f32" and "rcp"
template
<
>
...
...
@@ -488,7 +489,7 @@ struct FastGelu
const
float
u
=
x
*
(
c1
*
x
*
x
+
c2
);
const
float
emu
=
__ocml_exp_f32
(
u
);
y
=
x
*
ck
::
math
::
rcp
(
1.
f
+
emu
);
y
=
x
*
math
::
rcp
(
1.
f
+
emu
);
}
template
<
>
...
...
@@ -586,10 +587,9 @@ struct Gelu
}
template
<
>
__host__
__device__
void
operator
()
<
ck
::
half_t
,
ck
::
half_t
>
(
ck
::
half_t
&
y
,
const
ck
::
half_t
&
x
)
const
__host__
__device__
void
operator
()
<
half_t
,
half_t
>
(
half_t
&
y
,
const
half_t
&
x
)
const
{
y
=
ck
::
half_t
(
0.5
)
*
x
*
(
ck
::
half_t
(
1
)
+
ck
::
half_t
(
erf
(
float
(
0.70710678118
f
*
x
))));
y
=
half_t
(
0.5
)
*
x
*
(
half_t
(
1
)
+
half_t
(
erf
(
float
(
0.70710678118
f
*
x
))));
}
};
...
...
@@ -599,11 +599,11 @@ struct Sigmoid
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
ck
::
half_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
,
"Data type is not supported by this operation!"
);
constexpr
T
one
=
type_convert
<
T
>
(
1
);
y
=
one
/
(
one
+
ck
::
math
::
exp
(
-
x
));
y
=
one
/
(
one
+
math
::
exp
(
-
x
));
};
};
...
...
@@ -612,11 +612,11 @@ struct Silu
template
<
typename
T
>
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
is_same_v
<
T
,
float
>
||
is_same_v
<
T
,
double
>
||
is_same_v
<
T
,
ck
::
half_t
>
||
static_assert
(
is_same_v
<
T
,
float
>
||
is_same_v
<
T
,
double
>
||
is_same_v
<
T
,
half_t
>
||
is_same_v
<
T
,
int8_t
>
||
is_same_v
<
T
,
int32_t
>
,
"Data type is not supported by this operation!"
);
constexpr
T
one
=
type_convert
<
T
>
(
1
);
y
=
x
*
(
one
/
(
one
+
ck
::
math
::
exp
(
-
x
)));
y
=
x
*
(
one
/
(
one
+
math
::
exp
(
-
x
)));
};
};
...
...
@@ -626,11 +626,11 @@ struct TanH
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
ck
::
half_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
,
"Data type is not supported by this operation!"
);
y
=
ck
::
math
::
tanh
(
x
);
y
=
math
::
tanh
(
x
);
};
};
...
...
@@ -640,11 +640,11 @@ struct ACos
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
ck
::
half_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
,
"Data type is not supported by this operation!"
);
y
=
ck
::
math
::
acos
(
x
);
y
=
math
::
acos
(
x
);
};
};
...
...
@@ -654,11 +654,11 @@ struct Neg
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
ck
::
half_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
,
"Data type is not supported by this operation!"
);
y
=
ck
::
math
::
neg
(
x
);
y
=
math
::
neg
(
x
);
};
};
...
...
@@ -668,11 +668,11 @@ struct ATan
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
ck
::
half_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
,
"Data type is not supported by this operation!"
);
y
=
ck
::
math
::
atan
(
x
);
y
=
math
::
atan
(
x
);
};
};
...
...
@@ -682,11 +682,11 @@ struct Sin
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
ck
::
half_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
,
"Data type is not supported by this operation!"
);
y
=
ck
::
math
::
sin
(
x
);
y
=
math
::
sin
(
x
);
};
};
...
...
@@ -696,11 +696,11 @@ struct ASinH
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
ck
::
half_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
,
"Data type is not supported by this operation!"
);
y
=
ck
::
math
::
asinh
(
x
);
y
=
math
::
asinh
(
x
);
};
};
...
...
@@ -710,11 +710,11 @@ struct Cos
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
ck
::
half_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
,
"Data type is not supported by this operation!"
);
y
=
ck
::
math
::
cos
(
x
);
y
=
math
::
cos
(
x
);
};
};
...
...
@@ -724,11 +724,11 @@ struct ACosH
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
ck
::
half_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
,
"Data type is not supported by this operation!"
);
y
=
ck
::
math
::
acosh
(
x
);
y
=
math
::
acosh
(
x
);
};
};
...
...
@@ -738,11 +738,11 @@ struct Tan
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
ck
::
half_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
,
"Data type is not supported by this operation!"
);
y
=
ck
::
math
::
tan
(
x
);
y
=
math
::
tan
(
x
);
};
};
...
...
@@ -752,11 +752,11 @@ struct ATanH
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
ck
::
half_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
,
"Data type is not supported by this operation!"
);
y
=
ck
::
math
::
atanh
(
x
);
y
=
math
::
atanh
(
x
);
};
};
...
...
@@ -766,11 +766,11 @@ struct SinH
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
ck
::
half_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
,
"Data type is not supported by this operation!"
);
y
=
ck
::
math
::
sinh
(
x
);
y
=
math
::
sinh
(
x
);
};
};
...
...
@@ -780,11 +780,11 @@ struct Ceil
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
ck
::
half_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
,
"Data type is not supported by this operation!"
);
y
=
ck
::
math
::
ceil
(
x
);
y
=
math
::
ceil
(
x
);
};
};
...
...
@@ -794,11 +794,11 @@ struct Exp
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
ck
::
half_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
,
"Data type is not supported by this operation!"
);
y
=
ck
::
math
::
exp
(
x
);
y
=
math
::
exp
(
x
);
};
};
...
...
@@ -808,11 +808,11 @@ struct CosH
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
ck
::
half_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
,
"Data type is not supported by this operation!"
);
y
=
ck
::
math
::
cosh
(
x
);
y
=
math
::
cosh
(
x
);
};
};
...
...
@@ -822,11 +822,11 @@ struct Floor
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
ck
::
half_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
,
"Data type is not supported by this operation!"
);
y
=
ck
::
math
::
floor
(
x
);
y
=
math
::
floor
(
x
);
};
};
...
...
@@ -836,11 +836,11 @@ struct Log
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
ck
::
half_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
,
"Data type is not supported by this operation!"
);
y
=
ck
::
math
::
log
(
x
);
y
=
math
::
log
(
x
);
};
};
...
...
@@ -850,11 +850,11 @@ struct ASin
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
ck
::
half_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
,
"Data type is not supported by this operation!"
);
y
=
ck
::
math
::
asin
(
x
);
y
=
math
::
asin
(
x
);
};
};
...
...
@@ -864,11 +864,11 @@ struct Rcp
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
ck
::
half_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
,
"Data type is not supported by this operation!"
);
y
=
ck
::
math
::
rcp
(
x
);
y
=
math
::
rcp
(
x
);
};
};
...
...
@@ -880,15 +880,15 @@ struct Swish
__host__
__device__
void
operator
()(
Y
&
y
,
const
X
&
x
)
const
{
static_assert
(
is_same
<
X
,
float
>::
value
||
is_same
<
X
,
double
>::
value
||
is_same
<
X
,
ck
::
half_t
>::
value
,
is_same
<
X
,
half_t
>::
value
,
"Data type is not supported by this operation!"
);
static_assert
(
is_same
<
Y
,
float
>::
value
||
is_same
<
Y
,
double
>::
value
||
is_same
<
Y
,
ck
::
half_t
>::
value
,
is_same
<
Y
,
half_t
>::
value
,
"Data type is not supported by this operation!"
);
float
bx
=
-
beta_
*
type_convert
<
float
>
(
x
);
y
=
type_convert
<
Y
>
(
x
/
(
1.
f
+
ck
::
math
::
exp
(
bx
)));
y
=
type_convert
<
Y
>
(
x
/
(
1.
f
+
math
::
exp
(
bx
)));
};
const
float
beta_
;
...
...
@@ -907,7 +907,7 @@ struct SoftRelu
"Data type is not supported by this operation!"
);
T
casted_alpha
=
type_convert
<
T
>
(
alpha_
);
constexpr
T
one
=
type_convert
<
T
>
(
1
);
y
=
ck
::
math
::
log
(
one
+
ck
::
math
::
exp
(
x
*
casted_alpha
))
/
casted_alpha
;
y
=
math
::
log
(
one
+
math
::
exp
(
x
*
casted_alpha
))
/
casted_alpha
;
}
const
float
alpha_
;
};
...
...
@@ -928,7 +928,7 @@ struct Power
T
casted_beta
=
type_convert
<
T
>
(
beta_
);
T
casted_gamma
=
type_convert
<
T
>
(
gamma_
);
T
shifted_scaled_x
=
casted_alpha
+
casted_beta
*
x
;
y
=
ck
::
math
::
pow
(
shifted_scaled_x
,
casted_gamma
);
y
=
math
::
pow
(
shifted_scaled_x
,
casted_gamma
);
}
const
float
alpha_
;
const
float
beta_
;
...
...
@@ -948,7 +948,7 @@ struct ClippedRelu
"Data type is not supported by this operation!"
);
T
casted_alpha
=
type_convert
<
T
>
(
alpha_
);
T
casted_beta
=
type_convert
<
T
>
(
beta_
);
y
=
ck
::
math
::
min
(
casted_beta
,
ck
::
math
::
max
(
casted_alpha
,
x
));
y
=
math
::
min
(
casted_beta
,
math
::
max
(
casted_alpha
,
x
));
}
const
float
alpha_
;
const
float
beta_
;
...
...
@@ -983,7 +983,7 @@ struct Elu
is_same
<
T
,
int8_t
>::
value
,
"Data type is not supported by this operation!"
);
T
casted_alpha
=
type_convert
<
T
>
(
alpha_
);
y
=
x
>
0
?
x
:
casted_alpha
*
ck
::
math
::
expm1
(
x
);
y
=
x
>
0
?
x
:
casted_alpha
*
math
::
expm1
(
x
);
}
const
float
alpha_
;
};
...
...
@@ -1085,10 +1085,10 @@ struct FastNumericArrayConverter
};
template
<
>
struct
FastNumericArrayConverter
<
uint8_t
,
ck
::
half_t
,
4
>
struct
FastNumericArrayConverter
<
uint8_t
,
half_t
,
4
>
{
using
InputArray
=
vector_type
<
uint8_t
,
4
>
;
using
OutputArray
=
vector_type
<
ck
::
half_t
,
4
>
;
using
OutputArray
=
vector_type
<
half_t
,
4
>
;
__device__
static
OutputArray
convert
(
InputArray
const
&
Input
)
{
...
...
@@ -1118,13 +1118,13 @@ struct FastNumericArrayConverter<uint8_t, ck::half_t, 4>
};
template
<
index_t
N
>
struct
FastNumericArrayConverter
<
uint8_t
,
ck
::
half_t
,
N
>
struct
FastNumericArrayConverter
<
uint8_t
,
half_t
,
N
>
{
static
constexpr
int
VEC_WIDTH
=
4
;
static_assert
(
!
(
N
%
VEC_WIDTH
),
"N must be multiple of 4."
);
using
InputArray
=
vector_type
<
uint8_t
,
N
>
;
using
OutputArray
=
vector_type
<
ck
::
half_t
,
N
>
;
using
OutputArray
=
vector_type
<
half_t
,
N
>
;
__device__
static
OutputArray
convert
(
InputArray
const
&
Input
)
{
...
...
@@ -1133,7 +1133,7 @@ struct FastNumericArrayConverter<uint8_t, ck::half_t, N>
OutputArray
Output
;
using
Vec_InputArray
=
vector_type
<
uint8_t
,
4
>
;
using
Vec_OutputArray
=
vector_type
<
ck
::
half_t
,
4
>
;
using
Vec_OutputArray
=
vector_type
<
half_t
,
4
>
;
Vec_OutputArray
*
half_4_ptr
=
reinterpret_cast
<
Vec_OutputArray
*>
(
&
Output
);
Vec_InputArray
const
*
uint8_4_ptr
=
reinterpret_cast
<
Vec_InputArray
const
*>
(
&
Input
);
...
...
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
View file @
5e206700
...
...
@@ -981,7 +981,7 @@ struct BlockToCTileMap_3DGrid_KSplit
// Create 3D grid
const
auto
M0
=
math
::
integer_divide_ceil
(
M
,
MPerBlock
);
const
auto
N0
=
math
::
integer_divide_ceil
(
N
,
NPerBlock
);
return
ck
::
make_tuple
(
N0
,
M0
,
k_split
);
return
make_tuple
(
N0
,
M0
,
k_split
);
}
template
<
typename
TopIdx
>
...
...
@@ -1105,7 +1105,7 @@ struct BlockToCTileMap_GemmStreamK
uint32_t
dp_for_sk_iters
=
k_iters_per_tile
.
get
();
uint32_t
best_sk_score
=
ck
::
NumericLimits
<
int32_t
>::
Max
();
// we need to find the smallest sk iters
NumericLimits
<
int32_t
>::
Max
();
// we need to find the smallest sk iters
for
(
uint32_t
tentative_sk_blocks
=
min_sk_tiles
;
tentative_sk_blocks
<
max_sk_tiles
;
tentative_sk_blocks
++
)
{
...
...
include/ck/utility/data_type.hpp
View file @
5e206700
...
...
@@ -1075,10 +1075,10 @@ using uint8x64_t = typename vector_type<uint8_t, 64>::type;
template
<
typename
T
>
struct
NumericLimits
;
#ifndef CK_CODE_GEN_RTC
template
<
typename
T
>
struct
NumericLimits
{
#ifndef CK_CODE_GEN_RTC
__host__
__device__
static
constexpr
T
Min
()
{
return
std
::
numeric_limits
<
T
>::
min
();
}
__host__
__device__
static
constexpr
T
Max
()
{
return
std
::
numeric_limits
<
T
>::
max
();
}
__host__
__device__
static
constexpr
T
Lowest
()
{
return
std
::
numeric_limits
<
T
>::
lowest
();
}
...
...
@@ -1087,8 +1087,8 @@ struct NumericLimits
return
std
::
numeric_limits
<
T
>::
quiet_NaN
();
}
__host__
__device__
static
constexpr
T
Infinity
()
{
return
std
::
numeric_limits
<
T
>::
infinity
();
}
#endif
};
#endif
template
<
>
struct
NumericLimits
<
int32_t
>
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment