gaoqiong / composable_kernel_ROCM / Commits / d480a5a6

Unverified commit d480a5a6, authored Feb 03, 2025 by Max Podkorytov; committed by GitHub on Feb 03, 2025.

    Merge branch 'develop' into ck-flex

Parents: bca939ce, 9c5b2f39
Changes: 94 files in total; showing 20 changed files with 537 additions and 207 deletions (+537 −207).
Files shown (additions/deletions per file):

  include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp  +1 −0
  include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp  +1 −0
  include/ck/tensor_operation/gpu/device/tensor_layout.hpp  +2 −0
  include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp  +3 −3
  include/ck/tensor_operation/gpu/element/element_wise_operation.hpp  +2 −2
  include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp  +66 −65
  include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp  +6 −4
  include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp  +40 −13
  include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp  +13 −1
  include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp  +7 −2
  include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp  +2 −2
  include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp  +74 −63
  include/ck/utility/amd_buffer_addressing.hpp  +15 −1
  include/ck/utility/amd_ck_fp8.hpp  +15 −5
  include/ck/utility/amd_wave_read_first_lane.hpp  +14 −13
  include/ck/utility/array.hpp  +4 −2
  include/ck/utility/container_helper.hpp  +3 −3
  include/ck/utility/data_type.hpp  +249 −26
  include/ck/utility/debug.hpp  +2 −1
  include/ck/utility/enable_if.hpp  +18 −1
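The common thread in the hunks below is hipRTC compatibility: under the CK_CODE_GEN_RTC guard the headers must compile without the C++ standard library, so std::enable_if, std::array, std::byte, std::forward, std::numeric_limits and friends are swapped for ck:: equivalents, and host-only pieces (iostream output, <limits>, <array>) are fenced off. As a hedged sketch of the kind of shim these swaps rely on (the new contents of include/ck/utility/enable_if.hpp are not shown in this diff, so this is illustrative rather than the actual definition):

    namespace ck {
    // Minimal stand-in for std::enable_if that needs no standard library.
    template <bool B, typename T = void> struct enable_if {};
    template <typename T> struct enable_if<true, T> { using type = T; };
    template <bool B, typename T = void>
    using enable_if_t = typename enable_if<B, T>::type;
    } // namespace ck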
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp

@@ -9,6 +9,7 @@
 #include <numeric>
 #include <sstream>

+#include "ck/library/utility/numeric.hpp"
 #include "ck/utility/common_header.hpp"
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp

@@ -3,6 +3,7 @@
 #pragma once

+#include "ck/library/utility/numeric.hpp"
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
 #include "ck/tensor_operation/gpu/device/device_conv_tensor_rearrange.hpp"
include/ck/tensor_operation/gpu/device/tensor_layout.hpp

@@ -430,6 +430,7 @@ struct G_NDHW : public BaseTensorLayout
 } // namespace convolution

+#ifndef CK_CODE_GEN_RTC
 template <typename Layout,
           typename std::enable_if<std::is_base_of<BaseTensorLayout, Layout>::value, bool>::type =
               false>

@@ -438,6 +439,7 @@ std::ostream& operator<<(std::ostream& os, const Layout&)
     os << Layout::name;
     return os;
 }
+#endif

 } // namespace tensor_layout
 } // namespace ck
include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -340,8 +340,8 @@ struct Bilinear
     };

     template <>
     __host__ __device__ constexpr void
-    operator()<std::int8_t, std::int32_t, std::int8_t>(std::int8_t& y,
-                                                       const std::int32_t& x0,
-                                                       const std::int8_t& x1) const
+    operator()<int8_t, int32_t, int8_t>(int8_t& y, const int32_t& x0, const int8_t& x1) const
     {
         y = type_convert<int8_t>(alpha_ * type_convert<float>(x0) +
                                  beta_ * type_convert<float>(x1));
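A hedged usage sketch for the specialization above; the (alpha, beta) constructor is assumed from the alpha_/beta_ members and is not shown in this hunk:

    // y = alpha*x0 + beta*x1, accumulated in float, then narrowed to int8.
    ck::tensor_operation::element_wise::Bilinear op{0.5f, 2.0f};
    int8_t  y;
    int32_t x0 = 100;
    int8_t  x1 = 3;
    op(y, x0, x1); // y = type_convert<int8_t>(0.5f * 100 + 2.0f * 3) = 56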
include/ck/tensor_operation/gpu/element/element_wise_operation.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -533,7 +533,7 @@ struct NormalizeInInfer
                              const T3& gamma,
                              const T4& beta) const
     {
-        static_assert(std::is_same<T2, float>::value || std::is_same<T2, double>::value,
+        static_assert(is_same<T2, float>::value || is_same<T2, double>::value,
                       "Data type is not supported by this operation!");

         using ck::type_convert;
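For context: an inference-time normalization functor of this shape, with (gamma, beta) affine parameters, conventionally computes

    y = gamma * (x - mean) / sqrt(variance + epsilon) + beta

The body of NormalizeInInfer lies outside this hunk, so the formula is stated here as the standard assumption, not a quotation.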
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp

@@ -252,7 +252,7 @@ struct PassThroughPack2
     template <typename Y, typename X>
     __host__ __device__ void operator()(Y& y, const X& x) const;

-    __host__ __device__ constexpr void operator()(ck::half2_t& y, const ck::f8x2_t& x) const
+    __host__ __device__ constexpr void operator()(half2_t& y, const f8x2_t& x) const
     {
         auto t = type_convert<float2_t>(x);
         y      = type_convert<half2_t>(t);

@@ -479,7 +479,7 @@ struct PassThrough
     template <>
     __host__ __device__ void operator()<bf8_t, half_t>(bf8_t& y, const half_t& x) const
     {
-        y = ck::type_convert<bf8_t>(x);
+        y = type_convert<bf8_t>(x);
     }
 };

@@ -552,21 +552,21 @@ struct Scale
     template <typename Y, typename X>
     __host__ __device__ void operator()(Y& y, const X& x) const
     {
-        y = ck::type_convert<Y>(ck::type_convert<float>(x) * scale_);
+        y = type_convert<Y>(type_convert<float>(x) * scale_);
     }

     template <>
     __host__ __device__ void operator()<half_t, half_t>(half_t& y, const half_t& x) const
     {
-        y = ck::type_convert<half_t>(scale_) * x;
+        y = type_convert<half_t>(scale_) * x;
     };

     template <>
     __host__ __device__ void operator()<bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x) const
     {
-        const float x_tmp = ck::type_convert<float>(x);
+        const float x_tmp = type_convert<float>(x);
         const float y_tmp = scale_ * x_tmp;
-        y                 = ck::type_convert<bhalf_t>(y_tmp);
+        y                 = type_convert<bhalf_t>(y_tmp);
     };

     template <>

@@ -584,7 +584,7 @@ struct Scale
     template <>
     __host__ __device__ void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const
     {
-        y = ck::type_convert<int8_t>(scale_ * ck::type_convert<float>(x));
+        y = type_convert<int8_t>(scale_ * type_convert<float>(x));
     };

     float scale_;

@@ -600,7 +600,7 @@ struct ScaleAndResetNaNToMinusInfinity
     template <>
     __host__ __device__ void operator()<float, float>(float& y, const float& x) const
     {
-        y = ck::math::isnan(x) ? -ck::NumericLimits<float>::Infinity() : scale_ * x;
+        y = math::isnan(x) ? -NumericLimits<float>::Infinity() : scale_ * x;
     };

     float scale_;

@@ -671,12 +671,13 @@ struct UnaryAbs
     template <typename T>
     __host__ __device__ void operator()(T& y, const T& x) const
     {
         static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                           is_same<T, half_t>::value || is_same<T, int32_t>::value ||
                           is_same<T, int8_t>::value,
                       "Data type is not supported by this operation!");

-        y = ck::math::abs(x);
+        y = math::abs(x);
     };

     template <>

@@ -694,7 +695,7 @@ struct UnarySqrt
         static_assert(is_same<T, float>::value || is_same<T, double>::value,
                       "Data type is not supported by this operation!");

-        y = ck::math::sqrt(x);
+        y = math::sqrt(x);
     };
 };

@@ -713,9 +714,9 @@ struct Relu
     template <>
     __host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const
     {
-        float x_f32 = ck::type_convert<float>(x);
+        float x_f32 = type_convert<float>(x);
         float y_f32 = x_f32 > 0 ? x_f32 : 0;
-        y           = ck::type_convert<bhalf_t>(y_f32);
+        y           = type_convert<bhalf_t>(y_f32);
     }
 };

@@ -731,7 +732,7 @@ struct FastGelu
     template <typename Y, typename X>
     __device__ void operator()(Y& y, const X& x) const;

+#ifndef CK_CODE_GEN_RTC
     template <>
     __host__ void operator()<float, float>(float& y, const float& x) const
     {

@@ -742,6 +743,7 @@ struct FastGelu
         const float emu = exp(u);
         y               = x / (1.f + emu);
     }
+#endif

     // device code, use lower precision "__ocml_exp_f32" and "rcp"
     template <>

@@ -753,7 +755,7 @@ struct FastGelu
         const float u   = x * (c1 * x * x + c2);
         const float emu = __ocml_exp_f32(u);
-        y               = x * ck::math::rcp(1.f + emu);
+        y               = x * math::rcp(1.f + emu);
     }

     template <>
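Both FastGelu branches evaluate the same sigmoid form of the tanh-based GELU approximation; the device branch only swaps exp for __ocml_exp_f32 and the division for math::rcp. The constants c1 and c2 are defined outside the shown hunks, so the values below are the conventional ones, stated as an assumption:

    gelu(x) ≈ x / (1 + e^u),   u = x * (c1*x^2 + c2)
    c2 = -2*sqrt(2/pi) ≈ -1.5958,   c1 = 0.044715 * c2 ≈ -0.0714

which follows from the identity 0.5 * (1 + tanh(t)) = 1 / (1 + e^(-2t)) applied to the usual tanh approximation.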
@@ -851,10 +853,9 @@ struct Gelu
     }

     template <>
-    __host__ __device__ void operator()<ck::half_t, ck::half_t>(ck::half_t& y,
-                                                                const ck::half_t& x) const
+    __host__ __device__ void operator()<half_t, half_t>(half_t& y, const half_t& x) const
     {
-        y = ck::half_t(0.5) * x * (ck::half_t(1) + ck::half_t(erf(float(0.70710678118f * x))));
+        y = half_t(0.5) * x * (half_t(1) + half_t(erf(float(0.70710678118f * x))));
     }
 };
@@ -868,7 +869,7 @@ struct Sigmoid
                           is_same<T, int32_t>::value,
                       "Data type is not supported by this operation!");
         constexpr T one = type_convert<T>(1);
-        y               = one / (one + ck::math::exp(-x));
+        y               = one / (one + math::exp(-x));
     };
 };

@@ -877,11 +878,11 @@ struct Silu
     template <typename T>
     __host__ __device__ void operator()(T& y, const T& x) const
     {
-        static_assert(is_same_v<T, float> || is_same_v<T, double> || is_same_v<T, ck::half_t> ||
+        static_assert(is_same_v<T, float> || is_same_v<T, double> || is_same_v<T, half_t> ||
                           is_same_v<T, int8_t> || is_same_v<T, int32_t>,
                       "Data type is not supported by this operation!");
         constexpr T one = type_convert<T>(1);
-        y               = x * (one / (one + ck::math::exp(-x)));
+        y               = x * (one / (one + math::exp(-x)));
     };
 };

@@ -895,7 +896,7 @@ struct TanH
                           is_same<T, int32_t>::value,
                       "Data type is not supported by this operation!");

-        y = ck::math::tanh(x);
+        y = math::tanh(x);
     };
 };

@@ -905,11 +906,11 @@ struct ACos
     __host__ __device__ void operator()(T& y, const T& x) const
     {
         static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                          is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
+                          is_same<T, half_t>::value || is_same<T, int8_t>::value ||
                           is_same<T, int32_t>::value,
                       "Data type is not supported by this operation!");

-        y = ck::math::acos(x);
+        y = math::acos(x);
     };
 };

@@ -919,11 +920,11 @@ struct Neg
     __host__ __device__ void operator()(T& y, const T& x) const
     {
         static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                          is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
+                          is_same<T, half_t>::value || is_same<T, int8_t>::value ||
                           is_same<T, int32_t>::value,
                       "Data type is not supported by this operation!");

-        y = ck::math::neg(x);
+        y = math::neg(x);
     };
 };

@@ -933,11 +934,11 @@ struct ATan
     __host__ __device__ void operator()(T& y, const T& x) const
     {
         static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                          is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
+                          is_same<T, half_t>::value || is_same<T, int8_t>::value ||
                           is_same<T, int32_t>::value,
                       "Data type is not supported by this operation!");

-        y = ck::math::atan(x);
+        y = math::atan(x);
     };
 };

@@ -947,11 +948,11 @@ struct Sin
     __host__ __device__ void operator()(T& y, const T& x) const
     {
         static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                          is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
+                          is_same<T, half_t>::value || is_same<T, int8_t>::value ||
                           is_same<T, int32_t>::value,
                       "Data type is not supported by this operation!");

-        y = ck::math::sin(x);
+        y = math::sin(x);
     };
 };

@@ -961,11 +962,11 @@ struct ASinH
     __host__ __device__ void operator()(T& y, const T& x) const
     {
         static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                          is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
+                          is_same<T, half_t>::value || is_same<T, int8_t>::value ||
                           is_same<T, int32_t>::value,
                       "Data type is not supported by this operation!");

-        y = ck::math::asinh(x);
+        y = math::asinh(x);
     };
 };

@@ -975,11 +976,11 @@ struct Cos
     __host__ __device__ void operator()(T& y, const T& x) const
     {
         static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                          is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
+                          is_same<T, half_t>::value || is_same<T, int8_t>::value ||
                           is_same<T, int32_t>::value,
                       "Data type is not supported by this operation!");

-        y = ck::math::cos(x);
+        y = math::cos(x);
     };
 };

@@ -989,11 +990,11 @@ struct ACosH
     __host__ __device__ void operator()(T& y, const T& x) const
     {
         static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                          is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
+                          is_same<T, half_t>::value || is_same<T, int8_t>::value ||
                           is_same<T, int32_t>::value,
                       "Data type is not supported by this operation!");

-        y = ck::math::acosh(x);
+        y = math::acosh(x);
     };
 };

@@ -1003,11 +1004,11 @@ struct Tan
     __host__ __device__ void operator()(T& y, const T& x) const
     {
         static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                          is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
+                          is_same<T, half_t>::value || is_same<T, int8_t>::value ||
                           is_same<T, int32_t>::value,
                       "Data type is not supported by this operation!");

-        y = ck::math::tan(x);
+        y = math::tan(x);
     };
 };

@@ -1017,11 +1018,11 @@ struct ATanH
     __host__ __device__ void operator()(T& y, const T& x) const
     {
         static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                          is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
+                          is_same<T, half_t>::value || is_same<T, int8_t>::value ||
                           is_same<T, int32_t>::value,
                       "Data type is not supported by this operation!");

-        y = ck::math::atanh(x);
+        y = math::atanh(x);
     };
 };

@@ -1031,11 +1032,11 @@ struct SinH
     __host__ __device__ void operator()(T& y, const T& x) const
     {
         static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                          is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
+                          is_same<T, half_t>::value || is_same<T, int8_t>::value ||
                           is_same<T, int32_t>::value,
                       "Data type is not supported by this operation!");

-        y = ck::math::sinh(x);
+        y = math::sinh(x);
     };
 };

@@ -1045,11 +1046,11 @@ struct Ceil
     __host__ __device__ void operator()(T& y, const T& x) const
     {
         static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                          is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
+                          is_same<T, half_t>::value || is_same<T, int8_t>::value ||
                           is_same<T, int32_t>::value,
                       "Data type is not supported by this operation!");

-        y = ck::math::ceil(x);
+        y = math::ceil(x);
     };
 };

@@ -1059,11 +1060,11 @@ struct Exp
     __host__ __device__ void operator()(T& y, const T& x) const
     {
         static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                          is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
+                          is_same<T, half_t>::value || is_same<T, int8_t>::value ||
                           is_same<T, int32_t>::value,
                       "Data type is not supported by this operation!");

-        y = ck::math::exp(x);
+        y = math::exp(x);
     };
 };

@@ -1073,11 +1074,11 @@ struct CosH
     __host__ __device__ void operator()(T& y, const T& x) const
     {
         static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                          is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
+                          is_same<T, half_t>::value || is_same<T, int8_t>::value ||
                           is_same<T, int32_t>::value,
                       "Data type is not supported by this operation!");

-        y = ck::math::cosh(x);
+        y = math::cosh(x);
     };
 };

@@ -1087,11 +1088,11 @@ struct Floor
     __host__ __device__ void operator()(T& y, const T& x) const
     {
         static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                          is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
+                          is_same<T, half_t>::value || is_same<T, int8_t>::value ||
                           is_same<T, int32_t>::value,
                       "Data type is not supported by this operation!");

-        y = ck::math::floor(x);
+        y = math::floor(x);
     };
 };

@@ -1101,11 +1102,11 @@ struct Log
     __host__ __device__ void operator()(T& y, const T& x) const
     {
         static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                          is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
+                          is_same<T, half_t>::value || is_same<T, int8_t>::value ||
                           is_same<T, int32_t>::value,
                       "Data type is not supported by this operation!");

-        y = ck::math::log(x);
+        y = math::log(x);
     };
 };

@@ -1115,11 +1116,11 @@ struct ASin
     __host__ __device__ void operator()(T& y, const T& x) const
     {
         static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                          is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
+                          is_same<T, half_t>::value || is_same<T, int8_t>::value ||
                           is_same<T, int32_t>::value,
                       "Data type is not supported by this operation!");

-        y = ck::math::asin(x);
+        y = math::asin(x);
     };
 };

@@ -1129,11 +1130,11 @@ struct Rcp
     __host__ __device__ void operator()(T& y, const T& x) const
     {
         static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                          is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
+                          is_same<T, half_t>::value || is_same<T, int8_t>::value ||
                           is_same<T, int32_t>::value,
                       "Data type is not supported by this operation!");

-        y = ck::math::rcp(x);
+        y = math::rcp(x);
     };
 };

@@ -1153,7 +1154,7 @@ struct Swish
                       "Data type is not supported by this operation!");

         float bx = -beta_ * type_convert<float>(x);
-        y        = type_convert<Y>(x / (1.f + ck::math::exp(bx)));
+        y        = type_convert<Y>(x / (1.f + math::exp(bx)));
     };

     const float beta_;

@@ -1172,7 +1173,7 @@ struct SoftRelu
                       "Data type is not supported by this operation!");
         T casted_alpha  = type_convert<T>(alpha_);
         constexpr T one = type_convert<T>(1);
-        y = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha;
+        y = math::log(one + math::exp(x * casted_alpha)) / casted_alpha;
     }
     const float alpha_;
 };

@@ -1193,7 +1194,7 @@ struct Power
         T casted_beta      = type_convert<T>(beta_);
         T casted_gamma     = type_convert<T>(gamma_);
         T shifted_scaled_x = casted_alpha + casted_beta * x;
-        y                  = ck::math::pow(shifted_scaled_x, casted_gamma);
+        y                  = math::pow(shifted_scaled_x, casted_gamma);
     }
     const float alpha_;
     const float beta_;

@@ -1213,7 +1214,7 @@ struct ClippedRelu
                       "Data type is not supported by this operation!");
         T casted_alpha = type_convert<T>(alpha_);
         T casted_beta  = type_convert<T>(beta_);
-        y              = ck::math::min(casted_beta, ck::math::max(casted_alpha, x));
+        y              = math::min(casted_beta, math::max(casted_alpha, x));
     }
     const float alpha_;
     const float beta_;

@@ -1248,7 +1249,7 @@ struct Elu
                           is_same<T, int8_t>::value,
                       "Data type is not supported by this operation!");
         T casted_alpha = type_convert<T>(alpha_);
-        y              = x > 0 ? x : casted_alpha * ck::math::expm1(x);
+        y              = x > 0 ? x : casted_alpha * math::expm1(x);
     }
     const float alpha_;
 };

@@ -1350,10 +1351,10 @@ struct FastNumericArrayConverter
 };

 template <>
-struct FastNumericArrayConverter<uint8_t, ck::half_t, 4>
+struct FastNumericArrayConverter<uint8_t, half_t, 4>
 {
     using InputArray  = vector_type<uint8_t, 4>;
-    using OutputArray = vector_type<ck::half_t, 4>;
+    using OutputArray = vector_type<half_t, 4>;

     __device__ static OutputArray convert(InputArray const& Input)
     {

@@ -1383,13 +1384,13 @@ struct FastNumericArrayConverter<uint8_t, ck::half_t, 4>
 };

 template <index_t N>
-struct FastNumericArrayConverter<uint8_t, ck::half_t, N>
+struct FastNumericArrayConverter<uint8_t, half_t, N>
 {
     static constexpr int VEC_WIDTH = 4;
     static_assert(!(N % VEC_WIDTH), "N must be multiple of 4.");

     using InputArray  = vector_type<uint8_t, N>;
-    using OutputArray = vector_type<ck::half_t, N>;
+    using OutputArray = vector_type<half_t, N>;

     __device__ static OutputArray convert(InputArray const& Input)
     {

@@ -1398,7 +1399,7 @@ struct FastNumericArrayConverter<uint8_t, ck::half_t, N>
         OutputArray Output;

         using Vec_InputArray  = vector_type<uint8_t, 4>;
-        using Vec_OutputArray = vector_type<ck::half_t, 4>;
+        using Vec_OutputArray = vector_type<half_t, 4>;

         Vec_OutputArray* half_4_ptr       = reinterpret_cast<Vec_OutputArray*>(&Output);
         Vec_InputArray const* uint8_4_ptr = reinterpret_cast<Vec_InputArray const*>(&Input);
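A hedged usage sketch for the converter above; the entry point is the static convert shown in the diff, and any wrapper call operators are outside the shown hunks:

    // Convert 8 packed uint8 values to half precision on device. The
    // N-element specialization walks the 4-wide fast path, so N must be a
    // multiple of 4 (enforced by the static_assert above).
    using Conv8 = ck::tensor_operation::element_wise::
        FastNumericArrayConverter<uint8_t, ck::half_t, 8>;
    ck::vector_type<uint8_t, 8> in;   // filled elsewhere on device
    auto out = Conv8::convert(in);    // ck::vector_type<ck::half_t, 8>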
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

 #include "ck/utility/math.hpp"
 #include "ck/utility/number.hpp"
 #include "ck/utility/tuple.hpp"
 #include "ck/tensor_description/tensor_adaptor.hpp"
 #include "ck/tensor_description/multi_index_transform_helper.hpp"

+#ifndef CK_CODE_GEN_RTC
 #include <limits>
 #include <stdlib.h>
+#endif

 namespace ck
 {

@@ -978,8 +981,7 @@ struct BlockToCTileMap_3DGrid_KSplit
         // Create 3D grid
         const auto M0 = math::integer_divide_ceil(M, MPerBlock);
         const auto N0 = math::integer_divide_ceil(N, NPerBlock);
-
-        return std::make_tuple(N0, M0, k_split);
+        return make_tuple(N0, M0, k_split);
     }

     template <typename TopIdx>

@@ -1103,7 +1105,7 @@ struct BlockToCTileMap_GemmStreamK
         uint32_t dp_for_sk_iters = k_iters_per_tile.get();

-        uint32_t best_sk_score = std::numeric_limits<int>::max();
+        uint32_t best_sk_score = NumericLimits<int32_t>::Max();
         // we need to find the smallest sk iters
         for(uint32_t tentative_sk_blocks = min_sk_tiles; tentative_sk_blocks < max_sk_tiles;
             tentative_sk_blocks++)
         {
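The std::numeric_limits<int>::max() call becomes CK's own NumericLimits so the header no longer needs <limits> under hipRTC. A hedged sketch of the shape of that facility (CK's real definition lives in another header and may differ):

    namespace ck {
    template <typename T> struct NumericLimits;
    template <> struct NumericLimits<int32_t>
    {
        __host__ __device__ static constexpr int32_t Max() { return 2147483647; }
    };
    } // namespace ck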
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -423,10 +423,17 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
     }

     template <typename AsLayout, GemmSpecialization GemmSpec>
-    __host__ __device__ static auto
-    MakeAsGridDescriptor_M_K(const std::array<index_t, NumATensor>& MRaws,
-                             const std::array<index_t, NumATensor>& KRaws,
-                             const std::array<index_t, NumATensor>& AsStride)
+    __host__ __device__ static auto MakeAsGridDescriptor_M_K(
+#ifdef CK_CODE_GEN_RTC
+        const ck::Array<index_t, NumATensor>& MRaws,
+        const ck::Array<index_t, NumATensor>& KRaws,
+        const ck::Array<index_t, NumATensor>& AsStride
+#else
+        const std::array<index_t, NumATensor>& MRaws,
+        const std::array<index_t, NumATensor>& KRaws,
+        const std::array<index_t, NumATensor>& AsStride
+#endif
+        )
     {
         return generate_tuple(
             [&](auto i) {

@@ -462,10 +469,17 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
     }

     template <typename BsLayout, GemmSpecialization GemmSpec>
-    __host__ __device__ static auto
-    MakeBsGridDescriptor_N_K(const std::array<index_t, NumBTensor>& NRaws,
-                             const std::array<index_t, NumBTensor>& KRaws,
-                             const std::array<index_t, NumBTensor>& BsStride)
+    __host__ __device__ static auto MakeBsGridDescriptor_N_K(
+#ifdef CK_CODE_GEN_RTC
+        const ck::Array<index_t, NumBTensor>& NRaws,
+        const ck::Array<index_t, NumBTensor>& KRaws,
+        const ck::Array<index_t, NumBTensor>& BsStride
+#else
+        const std::array<index_t, NumBTensor>& NRaws,
+        const std::array<index_t, NumBTensor>& KRaws,
+        const std::array<index_t, NumBTensor>& BsStride
+#endif
+        )
     {
         return generate_tuple(
             [&](auto i) {

@@ -500,10 +514,17 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
     }

     template <typename DsLayout, GemmSpecialization GemmSpec>
-    __host__ __device__ static auto
-    MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws,
-                             const std::array<index_t, NumDTensor>& NRaws,
-                             const std::array<index_t, NumDTensor>& DsStride)
+    __host__ __device__ static auto MakeDsGridDescriptor_M_N(
+#ifdef CK_CODE_GEN_RTC
+        const ck::Array<index_t, NumDTensor>& MRaws,
+        const ck::Array<index_t, NumDTensor>& NRaws,
+        const ck::Array<index_t, NumDTensor>& DsStride
+#else
+        const std::array<index_t, NumDTensor>& MRaws,
+        const std::array<index_t, NumDTensor>& NRaws,
+        const std::array<index_t, NumDTensor>& DsStride
+#endif
+        )
     {
         return generate_tuple(
             [&](auto i) {

@@ -969,9 +990,15 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
                                                   const index_t M,
                                                   const index_t N,
                                                   const index_t K,
+#ifdef CK_CODE_GEN_RTC
+                                                  const ck::Array<index_t, NumATensor> StrideAs,
+                                                  const ck::Array<index_t, NumBTensor> StrideBs,
+                                                  const ck::Array<index_t, NumDTensor> StrideDs,
+#else
                                                   const std::array<index_t, NumATensor> StrideAs,
                                                   const std::array<index_t, NumBTensor> StrideBs,
                                                   const std::array<index_t, NumDTensor> StrideDs,
+#endif
                                                   const index_t StrideE,
                                                   const Block2ETileMap& block_2_etile_map)
     {
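Under CK_CODE_GEN_RTC the descriptor factories take ck::Array because hipRTC ships no <array>. A hedged, illustrative alias (not from the diff) that calling code could use to stay agnostic of the build mode:

    #ifdef CK_CODE_GEN_RTC
    template <typename T, ck::index_t N> using StrideArray = ck::Array<T, N>;
    #else
    template <typename T, ck::index_t N> using StrideArray = std::array<T, N>;
    #endif
    // e.g. StrideArray<ck::index_t, NumATensor> strides_a = {...};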
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -473,11 +473,19 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
         return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
     }

+#ifdef CK_CODE_GEN_RTC
+    template <typename DsLayout, GemmSpecialization GemmSpec>
+    __host__ __device__ static auto
+    MakeDsGridDescriptor_M_N(const ck::Array<index_t, NumDTensor>& MRaws,
+                             const ck::Array<index_t, NumDTensor>& NRaws,
+                             const ck::Array<index_t, NumDTensor>& DsStride)
+#else
     template <typename DsLayout, GemmSpecialization GemmSpec>
     __host__ __device__ static auto
     MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws,
                              const std::array<index_t, NumDTensor>& NRaws,
                              const std::array<index_t, NumDTensor>& DsStride)
+#endif
     {
         return generate_tuple(
             [&](auto i) {

@@ -941,7 +949,11 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
                                                   const index_t K,
                                                   const index_t StrideA,
                                                   const index_t StrideB,
+#ifdef CK_CODE_GEN_RTC
+                                                  const ck::Array<index_t, NumDTensor> StrideDs,
+#else
                                                   const std::array<index_t, NumDTensor> StrideDs,
+#endif
                                                   const index_t StrideE,
                                                   const Block2ETileMap& block_2_etile_map)
     {
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

+#ifndef CK_CODE_GEN_RTC
 #include <iostream>
 #include <ostream>
+#endif

 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp"

@@ -53,12 +54,15 @@ constexpr auto GridwiseGemmPipeline_Selector()
     }
     else
     {
+#ifndef CK_CODE_GEN_RTC
         std::cerr << "GridwiseGemmPipeline configuration is not available" << std::endl;
+#endif
     }
 }

 } // namespace ck

+#ifndef CK_CODE_GEN_RTC
 inline std::ostream& operator<<(std::ostream& os, const ck::PipelineVersion& p)
 {
     switch(p)

@@ -71,3 +75,4 @@ inline std::ostream& operator<<(std::ostream& os, const ck::PipelineVersion& p)
     }
     return os;
 }
+#endif
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -212,7 +212,7 @@ template <typename SrcData,
           typename enable_if<DstDesc::IsKnownAtCompileTime(), bool>::type = false>
 struct ThreadwiseTensorSliceTransfer_v2
 {
-    static_assert((InvalidElementAsNaN && !std::is_integral<DstData>::value) ||
+    static_assert((InvalidElementAsNaN && !ck::is_integral<DstData>::value) ||
                       (!InvalidElementAsNaN),
                   "Filling invalid element as NaN is only for floating point types");
include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

-#include "ck/library/utility/numeric.hpp"
 #include "ck/utility/common_header.hpp"
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"

@@ -148,8 +147,8 @@ struct TransformConvFwdToGemm
     template <typename ConvDimsType,
               typename ConvSpatialDimsType,
               index_t NDim = NDimSpatial,
-              typename std::enable_if<NDim == 1, bool>::type = false>
+              typename ck::enable_if<NDim == 1, bool>::type = false>
     __host__ __device__
     TransformConvFwdToGemm(const ConvDimsType& a_g_n_c_wis_lengths,
                            const ConvDimsType& a_g_n_c_wis_strides,
                            const ConvDimsType& b_g_k_c_xs_lengths,

@@ -201,11 +200,15 @@ struct TransformConvFwdToGemm
           InRightPadW_{input_right_pads[I0]},
           ZYX_{X_}
     {
+#ifdef CK_CODE_GEN_RTC
+        static_assert(is_same_v<ConvSpatialDimsType, ck::Array<IndexType, NDimSpatial>>);
+        static_assert(is_same_v<ConvDimsType, ck::Array<IndexType, NDimSpatial + I3>>);
+#else
         static_assert(is_same_v<ConvSpatialDimsType, std::array<IndexType, NDimSpatial>> ||
                       is_same_v<ConvSpatialDimsType, ck::Array<IndexType, NDimSpatial>>);
         static_assert(is_same_v<ConvDimsType, std::array<IndexType, NDimSpatial + I3>> ||
                       is_same_v<ConvDimsType, ck::Array<IndexType, NDimSpatial + I3>>);
+#endif
         if constexpr(SplitN)
         {
             N_ = GetSplitedNSize(

@@ -219,8 +222,8 @@ struct TransformConvFwdToGemm
     template <typename ConvDimsType,
               typename ConvSpatialDimsType,
               index_t NDim = NDimSpatial,
-              typename std::enable_if<NDim == 2, bool>::type = false>
+              typename ck::enable_if<NDim == 2, bool>::type = false>
     __host__ __device__
     TransformConvFwdToGemm(const ConvDimsType& a_g_n_c_wis_lengths,
                            const ConvDimsType& a_g_n_c_wis_strides,
                            const ConvDimsType& b_g_k_c_xs_lengths,

@@ -272,11 +275,15 @@ struct TransformConvFwdToGemm
           InRightPadW_{input_right_pads[I1]},
           ZYX_{Y_ * X_}
     {
+#ifdef CK_CODE_GEN_RTC
+        static_assert(is_same_v<ConvSpatialDimsType, ck::Array<IndexType, NDimSpatial>>);
+        static_assert(is_same_v<ConvDimsType, ck::Array<IndexType, NDimSpatial + I3>>);
+#else
         static_assert(is_same_v<ConvSpatialDimsType, std::array<IndexType, NDimSpatial>> ||
                       is_same_v<ConvSpatialDimsType, ck::Array<IndexType, NDimSpatial>>);
         static_assert(is_same_v<ConvDimsType, std::array<IndexType, NDimSpatial + I3>> ||
                       is_same_v<ConvDimsType, ck::Array<IndexType, NDimSpatial + I3>>);
+#endif
         if constexpr(SplitN)
         {
             N_ = GetSplitedNSize(

@@ -290,8 +297,8 @@ struct TransformConvFwdToGemm
     template <typename ConvDimsType,
               typename ConvSpatialDimsType,
               index_t NDim = NDimSpatial,
-              typename std::enable_if<NDim == 3, bool>::type = false>
+              typename ck::enable_if<NDim == 3, bool>::type = false>
     __host__ __device__
     TransformConvFwdToGemm(const ConvDimsType& a_g_n_c_wis_lengths,
                            const ConvDimsType& a_g_n_c_wis_strides,
                            const ConvDimsType& b_g_k_c_xs_lengths,

@@ -343,11 +350,15 @@ struct TransformConvFwdToGemm
           InRightPadW_{input_right_pads[I2]},
           ZYX_{Z_ * Y_ * X_}
     {
+#ifdef CK_CODE_GEN_RTC
+        static_assert(is_same_v<ConvSpatialDimsType, ck::Array<IndexType, NDimSpatial>>);
+        static_assert(is_same_v<ConvDimsType, ck::Array<IndexType, NDimSpatial + I3>>);
+#else
         static_assert(is_same_v<ConvSpatialDimsType, std::array<IndexType, NDimSpatial>> ||
                       is_same_v<ConvSpatialDimsType, ck::Array<IndexType, NDimSpatial>>);
         static_assert(is_same_v<ConvDimsType, std::array<IndexType, NDimSpatial + I3>> ||
                       is_same_v<ConvDimsType, ck::Array<IndexType, NDimSpatial + I3>>);
+#endif
         if constexpr(SplitN)
         {
             N_ = GetSplitedNSize(

@@ -478,11 +489,11 @@ struct TransformConvFwdToGemm
     // TODO: implement ck::tensor_layout::convolution that describe packed/strided dimemsion as
     // properties
     template <typename ALayout,
-              typename std::enable_if<
+              typename ck::enable_if<
                   NDimSpatial == 1 &&
                       (is_same_v<ALayout, tensor_layout::convolution::G_NW_C> ||
                        is_same_v<ALayout, tensor_layout::convolution::NWGC> ||
                        is_same_v<ALayout, tensor_layout::convolution::GNWC>),
                   bool>::type = false>
     __host__ __device__ auto MakeADescriptor_M_K() const
     {
         if constexpr(ConvForwardSpecialization ==

@@ -691,11 +702,11 @@ struct TransformConvFwdToGemm
     }

     template <typename ALayout,
-              typename std::enable_if<
+              typename ck::enable_if<
                   NDimSpatial == 2 &&
                       (is_same_v<ALayout, tensor_layout::convolution::G_NHW_C> ||
                        is_same_v<ALayout, tensor_layout::convolution::NHWGC> ||
                        is_same_v<ALayout, tensor_layout::convolution::GNHWC>),
                   bool>::type = false>
     __host__ __device__ auto MakeADescriptor_M_K() const
     {

@@ -932,7 +943,7 @@ struct TransformConvFwdToGemm
     }

     template <typename ALayout,
-              typename std::enable_if<
+              typename ck::enable_if<
                   NDimSpatial == 3 &&
                       (is_same_v<ALayout, tensor_layout::convolution::G_NDHW_C> ||
                        is_same_v<ALayout, tensor_layout::convolution::NDHWGC> ||
                        is_same_v<ALayout, tensor_layout::convolution::GNDHWC>),

@@ -1242,19 +1253,19 @@ struct TransformConvFwdToGemm
     }

     template <typename BLayout,
-              typename std::enable_if<is_same_v<BLayout, tensor_layout::convolution::GKXC> ||
-                                          is_same_v<BLayout, tensor_layout::convolution::GKYXC> ||
-                                          is_same_v<BLayout, tensor_layout::convolution::GKZYXC>,
-                                      bool>::type = false>
+              typename ck::enable_if<is_same_v<BLayout, tensor_layout::convolution::GKXC> ||
+                                         is_same_v<BLayout, tensor_layout::convolution::GKYXC> ||
+                                         is_same_v<BLayout, tensor_layout::convolution::GKZYXC>,
+                                     bool>::type = false>
     __host__ __device__ auto MakeBDescriptor_N_K() const
     {
         if constexpr(ConvForwardSpecialization ==
                      device::ConvolutionForwardSpecialization::Filter3x3)
         {
             using FilterSizeNumType =
-                std::conditional_t<NDimSpatial == 1,
-                                   Number<3>,
-                                   std::conditional_t<NDimSpatial == 2, Number<9>, Number<27>>>;
+                ck::conditional_t<NDimSpatial == 1,
+                                  Number<3>,
+                                  ck::conditional_t<NDimSpatial == 2, Number<9>, Number<27>>>;

             if constexpr(NumGroupsToMerge == 1)
             {

@@ -1297,13 +1308,13 @@ struct TransformConvFwdToGemm
     template <typename BLayout,
-              typename std::enable_if<
+              typename ck::enable_if<
                   is_same_v<BLayout, tensor_layout::convolution::G_K_X_C> ||
                       is_same_v<BLayout, tensor_layout::convolution::G_K_YX_C> ||
                       is_same_v<BLayout, tensor_layout::convolution::G_K_ZYX_C> ||
                       is_same_v<BLayout, tensor_layout::convolution::KXGC> ||
                       is_same_v<BLayout, tensor_layout::convolution::KYXGC> ||
                       is_same_v<BLayout, tensor_layout::convolution::KZYXGC>,
                   bool>::type = false>
     __host__ __device__ auto MakeBDescriptor_N_K() const
     {
         const auto wei_k_yx_c_desc = make_naive_tensor_descriptor(

@@ -1318,36 +1329,36 @@ struct TransformConvFwdToGemm
         return wei_gemmn_gemmk_desc;
     }

     template <typename CLayout,
               index_t NDimSp = NDimSpatial,
-              typename std::enable_if<NDimSp == 1 &&
-                                          (is_same_v<CLayout, tensor_layout::convolution::G_K>),
-                                      bool>::type = false>
+              typename ck::enable_if<NDimSp == 1 &&
+                                         (is_same_v<CLayout, tensor_layout::convolution::G_K>),
+                                     bool>::type = false>
     __host__ __device__ auto MakeCDescriptor_M_N() const
     {
         return make_naive_tensor_descriptor(make_tuple(N_ * Wo_, K_),
                                             make_tuple(I0, KStrideTensorC_));
     }

     template <typename CLayout,
               index_t NDimSp = NDimSpatial,
-              typename std::enable_if<NDimSp == 2 &&
-                                          (is_same_v<CLayout, tensor_layout::convolution::G_K>),
-                                      bool>::type = false>
+              typename ck::enable_if<NDimSp == 2 &&
+                                         (is_same_v<CLayout, tensor_layout::convolution::G_K>),
+                                     bool>::type = false>
     __host__ __device__ auto MakeCDescriptor_M_N() const
     {
         return make_naive_tensor_descriptor(make_tuple(N_ * Ho_ * Wo_, K_),
                                             make_tuple(I0, KStrideTensorC_));
     }

     template <typename CLayout,
               index_t NDimSp = NDimSpatial,
-              typename std::enable_if<NDimSp == 3 &&
-                                          (is_same_v<CLayout, tensor_layout::convolution::G_K>),
-                                      bool>::type = false>
+              typename ck::enable_if<NDimSp == 3 &&
+                                         (is_same_v<CLayout, tensor_layout::convolution::G_K>),
+                                     bool>::type = false>
     __host__ __device__ auto MakeCDescriptor_M_N() const
     {
         return make_naive_tensor_descriptor(make_tuple(N_ * Do_ * Ho_ * Wo_, K_),

@@ -1355,12 +1366,12 @@ struct TransformConvFwdToGemm
     }

     template <typename CLayout,
               index_t NDimSp = NDimSpatial,
-              typename std::enable_if<
+              typename ck::enable_if<
                   NDimSp == 1 &&
                       (is_same_v<CLayout, tensor_layout::convolution::G_NW_K> ||
                        is_same_v<CLayout, tensor_layout::convolution::NWGK> ||
                        is_same_v<CLayout, tensor_layout::convolution::GNWK>),
                   bool>::type = false>
     __host__ __device__ auto MakeCDescriptor_M_N() const
     {
         const IndexType NDoHoWo = N_ * Wo_;

@@ -1410,11 +1421,11 @@ struct TransformConvFwdToGemm
     template <typename CLayout,
               index_t NDimSp = NDimSpatial,
-              typename std::enable_if<
+              typename ck::enable_if<
                   NDimSp == 2 &&
                       (is_same_v<CLayout, tensor_layout::convolution::G_NHW_K> ||
                        is_same_v<CLayout, tensor_layout::convolution::NHWGK> ||
                        is_same_v<CLayout, tensor_layout::convolution::GNHWK>),
                   bool>::type = false>
     __host__ __device__ auto MakeCDescriptor_M_N() const
     {
         const IndexType NDoHoWo = N_ * Ho_ * Wo_;

@@ -1467,7 +1478,7 @@ struct TransformConvFwdToGemm
     template <typename CLayout,
               index_t NDimSp = NDimSpatial,
-              typename std::enable_if<
+              typename ck::enable_if<
                   NDimSp == 3 &&
                       (is_same_v<CLayout, tensor_layout::convolution::G_NDHW_K> ||
                        is_same_v<CLayout, tensor_layout::convolution::NDHWGK> ||
                        is_same_v<CLayout, tensor_layout::convolution::GNDHWK>),
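For orientation, the implicit-GEMM view these descriptors construct maps forward convolution onto GEMM dimensions roughly as follows (standard im2col bookkeeping, stated as background rather than quoted from the file):

    GemmM = N * Do * Ho * Wo   (batch times output spatial points)
    GemmN = K                  (output channels)
    GemmK = C * Z * Y * X      (input channels times filter taps)

The G_K output-layout overloads above use stride I0 for the GemmM dimension, broadcasting one K-length row across every output position.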
include/ck/utility/amd_buffer_addressing.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

 #include "data_type.hpp"

@@ -1021,15 +1021,24 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr,
     constexpr auto bytes_per_thread = sizeof(T) * NumElemsPerThread;
     static_assert(bytes_per_thread == dword_bytes);

+#ifndef CK_CODE_GEN_RTC
     const uint32_t* global_ptr =
         reinterpret_cast<uint32_t*>(reinterpret_cast<uintptr_t>(global_base_ptr));
+#else
+    const uint32_t* global_ptr =
+        reinterpret_cast<uint32_t*>(reinterpret_cast<size_t>(global_base_ptr));
+#endif

     const int32x4_t src_resource = make_wave_buffer_resource(global_ptr, src_element_space_size);
     const index_t global_offset_bytes = is_valid ? global_offset * sizeof(T) : 0x80000000;

 #if CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM
     T* lds_ptr = lds_base_ptr + lds_offset;
+#ifndef CK_CODE_GEN_RTC
     auto const lds_ptr_sgpr =
         __builtin_amdgcn_readfirstlane((reinterpret_cast<uintptr_t>(lds_ptr)));
+#else
+    auto const lds_ptr_sgpr =
+        __builtin_amdgcn_readfirstlane((reinterpret_cast<size_t>(lds_ptr)));
+#endif
     asm volatile("s_mov_b32 m0, %0; \n\t"
                  "buffer_load_dword %1, %2, 0 offen lds;\n\t" ::"s"(lds_ptr_sgpr),
                  "v"(global_offset_bytes),

@@ -1038,8 +1047,13 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr,
 #else
     // LDS pointer must be attributed with the LDS address space.
     __attribute__((address_space(3))) uint32_t* lds_ptr =
+#ifndef CK_CODE_GEN_RTC
         reinterpret_cast<__attribute__((address_space(3))) uint32_t*>(
             reinterpret_cast<uintptr_t>(lds_base_ptr + lds_offset));
+#else
+        reinterpret_cast<__attribute__((address_space(3))) uint32_t*>(
+            reinterpret_cast<size_t>(lds_base_ptr + lds_offset));
+#endif

     llvm_amdgcn_raw_buffer_load_lds(
         src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0);
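On the 0x80000000 sentinel above: buffer instructions on GCN/CDNA drop accesses whose offset falls outside the resource's range, so forcing the high bit of the byte offset turns invalid lanes into no-ops (reads return 0) without a branch. That reading of the hardware behavior is standard for AMD buffer addressing and is stated here as background, not quoted from the file.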
include/ck/utility/amd_ck_fp8.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

 #include "ck/ck.hpp"
 #include "ck/utility/enable_if.hpp"
 #include "ck/utility/random_gen.hpp"
 #include "ck/utility/type.hpp"

@@ -424,9 +426,9 @@ __host__ __device__ inline constexpr bool fp8_is_nan(bf8_fnuz_t a)
 }

 template <typename T,
-          std::enable_if_t<std::is_same_v<T, bf8_ocp_t> || std::is_same_v<T, f8_ocp_t> ||
-                               std::is_same_v<T, bf8_fnuz_t> || std::is_same_v<T, f8_fnuz_t>,
-                           bool> = true>
+          ck::enable_if_t<is_same_v<T, bf8_ocp_t> || is_same_v<T, f8_ocp_t> ||
+                              is_same_v<T, bf8_fnuz_t> || is_same_v<T, f8_fnuz_t>,
+                          bool> = true>
 __host__ __device__ static inline constexpr bool fp8_is_inf(T)
 {
     return false;

@@ -823,7 +825,11 @@ __host__ __device__ static inline fp8_storage_t cvt_float_to_fp8(const float f)
     if constexpr(stochastic_rounding)
     {
         constexpr int seed = 1254739;
-        rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f);
+#ifndef CK_CODE_GEN_RTC
+        rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f);
+#else
+        rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&f), f);
+#endif
     }

     return cast_to_f8_from_f32<interp, sat == ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
         f, rng);

@@ -839,7 +845,11 @@ __host__ static inline fp8_storage_t cvt_float_to_fp8(const float f)
     if constexpr(stochastic_rounding)
     {
         constexpr int seed = 1254739;
+#ifndef CK_CODE_GEN_RTC
         rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f);
+#else
+        rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&f), f);
+#endif
     }

     if constexpr(interp == ck_fp8_interpretation_t::CK_E4M3_FNUZ)
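Stochastic-rounding context (background, not quoted from the file): the PRNG is seeded from the value's address and bits, and the random draw decides whether the float rounds up or down, with probability proportional to the distance to each representable fp8 neighbor. Writing lo and hi for the bracketing fp8 values, that choice makes the rounding unbiased in expectation:

    E[round_sr(x)] = lo * (hi - x)/(hi - lo) + hi * (x - lo)/(hi - lo) = x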
include/ck/utility/amd_wave_read_first_lane.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -7,10 +7,12 @@
 #include "ck/utility/functional2.hpp"
 #include "ck/utility/math.hpp"

+#ifndef CK_CODE_GEN_RTC
 #include <array>
 #include <cstddef>
 #include <cstdint>
 #include <type_traits>
+#endif

 namespace ck
 {
 namespace detail
 {

@@ -37,7 +39,7 @@ struct get_carrier<3>
 {
     using value_type = uint32_t;

-    std::array<std::byte, 3> bytes;
+    Array<ck::byte, 3> bytes;

     static_assert(sizeof(bytes) <= sizeof(value_type));

     // replacement of host std::copy_n()

@@ -61,22 +63,22 @@ struct get_carrier<3>
     // method to trigger template substitution failure
     __device__ carrier(const carrier& other) noexcept
     {
-        copy_n(other.bytes.begin(), bytes.size(), bytes.begin());
+        copy_n(other.bytes.begin(), bytes.Size(), bytes.begin());
     }

     public:
     __device__ carrier& operator=(value_type value) noexcept
     {
-        copy_n(reinterpret_cast<const std::byte*>(&value), bytes.size(), bytes.begin());
+        copy_n(reinterpret_cast<const ck::byte*>(&value), bytes.Size(), bytes.begin());

         return *this;
     }

     __device__ operator value_type() const noexcept
     {
-        std::byte result[sizeof(value_type)];
+        ck::byte result[sizeof(value_type)];

-        copy_n(bytes.begin(), bytes.size(), result);
+        copy_n(bytes.begin(), bytes.Size(), result);

         return *reinterpret_cast<const value_type*>(result);
     }

@@ -109,8 +111,8 @@ __device__ inline int64_t amd_wave_read_first_lane(int64_t value)
 {
     constexpr unsigned object_size        = sizeof(int64_t);
     constexpr unsigned second_part_offset = object_size / 2;

-    auto* const from_obj = reinterpret_cast<const std::byte*>(&value);
-    alignas(int64_t) std::byte to_obj[object_size];
+    auto* const from_obj = reinterpret_cast<const ck::byte*>(&value);
+    alignas(int64_t) ck::byte to_obj[object_size];

     using Sgpr = uint32_t;

@@ -122,17 +124,16 @@ __device__ inline int64_t amd_wave_read_first_lane(int64_t value)
     return *reinterpret_cast<int64_t*>(to_obj);
 }

 template <typename Object,
-          typename = std::enable_if_t<std::is_class_v<Object> &&
-                                      std::is_trivially_copyable_v<Object>>>
+          typename = ck::enable_if_t<ck::is_class_v<Object> &&
+                                     ck::is_trivially_copyable_v<Object>>>
 __device__ auto amd_wave_read_first_lane(const Object& obj)
 {
     using Size = unsigned;

     constexpr Size SgprSize   = 4;
     constexpr Size ObjectSize = sizeof(Object);

-    auto* const from_obj = reinterpret_cast<const std::byte*>(&obj);
-    alignas(Object) std::byte to_obj[ObjectSize];
+    auto* const from_obj = reinterpret_cast<const ck::byte*>(&obj);
+    alignas(Object) ck::byte to_obj[ObjectSize];

     constexpr Size RemainedSize             = ObjectSize % SgprSize;
     constexpr Size CompleteSgprCopyBoundary = ObjectSize - RemainedSize;
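A hedged usage sketch for the trivially-copyable-object overload above; the struct is illustrative, and the constraints are the ones in the diff:

    struct TileCoord
    {
        int m;
        int n;
    };

    __device__ void use(TileCoord c)
    {
        // Broadcast lane 0's copy so every lane in the wave holds the same
        // value in scalar registers, moved 4 bytes (one SGPR) at a time.
        TileCoord uniform = ck::amd_wave_read_first_lane(c);
        (void)uniform;
    }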
include/ck/utility/array.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #ifndef CK_ARRAY_HPP
 #define CK_ARRAY_HPP

@@ -38,6 +38,8 @@ struct Array
     }

     __host__ __device__ constexpr const TData* begin() const { return &mData[0]; }
     __host__ __device__ constexpr const TData* end() const { return &mData[NSize]; }
+    __host__ __device__ constexpr TData* begin() { return &mData[0]; }
+    __host__ __device__ constexpr TData* end() { return &mData[NSize]; }
 };

 // empty Array

@@ -54,7 +56,7 @@ template <typename X, typename... Xs>
 __host__ __device__ constexpr auto make_array(X&& x, Xs&&... xs)
 {
     using data_type = remove_cvref_t<X>;
-    return Array<data_type, sizeof...(Xs) + 1>{std::forward<X>(x), std::forward<Xs>(xs)...};
+    return Array<data_type, sizeof...(Xs) + 1>{ck::forward<X>(x), ck::forward<Xs>(xs)...};
 }

 // make empty array
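The non-const begin()/end() overloads make ck::Array usable with mutating range-based for loops and pointer algorithms. A hedged sketch (make_array is the factory shown in the hunk above):

    auto a = ck::make_array(1, 2, 3, 4); // ck::Array<int, 4>
    for(int& v : a)                      // picks the non-const begin()/end()
    {
        v *= 2;
    }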
include/ck/utility/container_helper.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #ifndef CK_CONTAINER_HELPER_HPP
 #define CK_CONTAINER_HELPER_HPP

@@ -326,14 +326,14 @@ template <typename T, index_t NX, index_t NY>
 __host__ __device__ constexpr auto container_concat(const Array<T, NX>& ax, const Array<T, NY>& ay)
 {
     return unpack2(
-        [&](auto&&... zs) { return make_array(std::forward<decltype(zs)>(zs)...); }, ax, ay);
+        [&](auto&&... zs) { return make_array(ck::forward<decltype(zs)>(zs)...); }, ax, ay);
 }

 template <typename... X, typename... Y>
 __host__ __device__ constexpr auto container_concat(const Tuple<X...>& tx, const Tuple<Y...>& ty)
 {
     return unpack2(
-        [&](auto&&... zs) { return make_tuple(std::forward<decltype(zs)>(zs)...); }, tx, ty);
+        [&](auto&&... zs) { return make_tuple(ck::forward<decltype(zs)>(zs)...); }, tx, ty);
 }

 template <typename Container>
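ck::forward replaces std::forward so <utility> is not needed under hipRTC. A hedged sketch of the minimal definition such a swap implies; CK's actual one lives in a header not shown here and may differ:

    namespace ck {
    // Perfect forwarding without the standard library; assumes a
    // remove_reference trait with the usual meaning exists in ck::.
    template <typename T>
    __host__ __device__ constexpr T&&
    forward(typename remove_reference<T>::type& t) noexcept
    {
        return static_cast<T&&>(t);
    }
    } // namespace ck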
include/ck/utility/data_type.hpp

@@ -5,9 +5,21 @@
 #include "ck/utility/amd_ck_fp8.hpp"
 #include "ck/utility/statically_indexed_array.hpp"

+#ifdef CK_CODE_GEN_RTC
+using int8_t   = signed char;
+using uint8_t  = unsigned char;
+using int16_t  = signed short;
+using uint16_t = unsigned short;
+using float_t  = float;
+#endif

 namespace ck
 {
+#ifdef CK_CODE_GEN_RTC
+using byte = unsigned char;
+#else
 using std::byte;
+#endif

 using bhalf_t = ushort;
 using half_t  = _Float16;
 using int4_t  = _BitInt(4);
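These global-scope aliases stand in for <cstdint>, which hipRTC does not ship, and ck::byte replaces std::byte for the same reason. How the guard macro gets defined is outside this diff; a hedged sketch of a runtime-compilation setup would pass it as a define:

    // assumption: CK_CODE_GEN_RTC is supplied when invoking hipRTC
    const char* options[] = {"-DCK_CODE_GEN_RTC", "-std=c++17"};
    // hiprtcCompileProgram(prog, 2, options);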
@@ -217,7 +229,7 @@ struct scalar_type<bool>
};
template
<
typename
T
>
struct
vector_type
<
T
,
1
,
typename
std
::
enable_if_t
<
is_native_type
<
T
>
()
>>
struct
vector_type
<
T
,
1
,
typename
ck
::
enable_if_t
<
is_native_type
<
T
>
()
>>
{
using
d1_t
=
T
;
using
type
=
d1_t
;
...
...
@@ -253,7 +265,7 @@ struct vector_type<T, 1, typename std::enable_if_t<is_native_type<T>()>>
__device__
int
static
err
=
0
;
template
<
typename
T
>
struct
vector_type
<
T
,
2
,
typename
std
::
enable_if_t
<
is_native_type
<
T
>
()
>>
struct
vector_type
<
T
,
2
,
typename
ck
::
enable_if_t
<
is_native_type
<
T
>
()
>>
{
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
...
...
@@ -313,7 +325,7 @@ struct vector_type<T, 2, typename std::enable_if_t<is_native_type<T>()>>
};
template
<
typename
T
>
struct
vector_type
<
T
,
3
,
typename
std
::
enable_if_t
<
is_native_type
<
T
>
()
>>
struct
vector_type
<
T
,
3
,
typename
ck
::
enable_if_t
<
is_native_type
<
T
>
()
>>
{
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
...
...
@@ -383,7 +395,7 @@ struct vector_type<T, 3, typename std::enable_if_t<is_native_type<T>()>>
};
template
<
typename
T
>
struct
vector_type
<
T
,
4
,
typename
std
::
enable_if_t
<
is_native_type
<
T
>
()
>>
struct
vector_type
<
T
,
4
,
typename
ck
::
enable_if_t
<
is_native_type
<
T
>
()
>>
{
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
...
...
@@ -453,7 +465,7 @@ struct vector_type<T, 4, typename std::enable_if_t<is_native_type<T>()>>
};
template
<
typename
T
>
struct
vector_type
<
T
,
5
,
typename
std
::
enable_if_t
<
is_native_type
<
T
>
()
>>
struct
vector_type
<
T
,
5
,
typename
ck
::
enable_if_t
<
is_native_type
<
T
>
()
>>
{
using
d1_t
=
T
;
typedef
T
d4_t
__attribute__
((
ext_vector_type
(
4
)));
...
...
@@ -523,7 +535,7 @@ struct vector_type<T, 5, typename std::enable_if_t<is_native_type<T>()>>
};
template
<
typename
T
>
struct
vector_type
<
T
,
7
,
typename
std
::
enable_if_t
<
is_native_type
<
T
>
()
>>
struct
vector_type
<
T
,
7
,
typename
ck
::
enable_if_t
<
is_native_type
<
T
>
()
>>
{
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
...
...
@@ -605,7 +617,7 @@ struct vector_type<T, 7, typename std::enable_if_t<is_native_type<T>()>>
};
template
<
typename
T
>
struct
vector_type
<
T
,
8
,
typename
std
::
enable_if_t
<
is_native_type
<
T
>
()
>>
struct
vector_type
<
T
,
8
,
typename
ck
::
enable_if_t
<
is_native_type
<
T
>
()
>>
{
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
...
...
@@ -687,7 +699,7 @@ struct vector_type<T, 8, typename std::enable_if_t<is_native_type<T>()>>
};
template
<
typename
T
>
struct
vector_type
<
T
,
13
,
typename
std
::
enable_if_t
<
is_native_type
<
T
>
()
>>
struct
vector_type
<
T
,
13
,
typename
ck
::
enable_if_t
<
is_native_type
<
T
>
()
>>
{
using
d1_t
=
T
;
typedef
T
d4_t
__attribute__
((
ext_vector_type
(
4
)));
...
...
@@ -769,7 +781,7 @@ struct vector_type<T, 13, typename std::enable_if_t<is_native_type<T>()>>
};
template
<
typename
T
>
struct
vector_type
<
T
,
16
,
typename
std
::
enable_if_t
<
is_native_type
<
T
>
()
>>
struct
vector_type
<
T
,
16
,
typename
ck
::
enable_if_t
<
is_native_type
<
T
>
()
>>
{
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
...
...
@@ -863,7 +875,7 @@ struct vector_type<T, 16, typename std::enable_if_t<is_native_type<T>()>>
};
template
<
typename
T
>
struct
vector_type
<
T
,
32
,
typename
std
::
enable_if_t
<
is_native_type
<
T
>
()
>>
struct
vector_type
<
T
,
32
,
typename
ck
::
enable_if_t
<
is_native_type
<
T
>
()
>>
{
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
...
...
@@ -967,7 +979,7 @@ struct vector_type<T, 32, typename std::enable_if_t<is_native_type<T>()>>
};
template
<
typename
T
>
struct
vector_type
<
T
,
64
,
typename
std
::
enable_if_t
<
is_native_type
<
T
>
()
>>
struct
vector_type
<
T
,
64
,
typename
ck
::
enable_if_t
<
is_native_type
<
T
>
()
>>
{
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
...
...
@@ -1083,7 +1095,7 @@ struct vector_type<T, 64, typename std::enable_if_t<is_native_type<T>()>>
};
template
<
typename
T
>
struct
vector_type
<
T
,
128
,
typename
std
::
enable_if_t
<
is_native_type
<
T
>
()
>>
struct
vector_type
<
T
,
128
,
typename
ck
::
enable_if_t
<
is_native_type
<
T
>
()
>>
{
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
...
...
@@ -1209,7 +1221,7 @@ struct vector_type<T, 128, typename std::enable_if_t<is_native_type<T>()>>
 };
 
 template <typename T>
-struct vector_type<T, 256, typename std::enable_if_t<is_native_type<T>()>>
+struct vector_type<T, 256, typename ck::enable_if_t<is_native_type<T>()>>
 {
     using d1_t = T;
     typedef T d2_t __attribute__((ext_vector_type(2)));
...
@@ -1374,7 +1386,7 @@ template <typename T, index_t N>
-struct non_native_vector_base<T, N, std::enable_if_t<sizeof(T) == 1 || sizeof(T) == 2 || sizeof(T) == 4 || sizeof(T) == 8>>
+struct non_native_vector_base<T, N, ck::enable_if_t<sizeof(T) == 1 || sizeof(T) == 2 || sizeof(T) == 4 || sizeof(T) == 8>>
 {
     using data_t = typename nnvb_data_t_selector<T>::type; // select data_t based on the size of T
     static_assert(sizeof(T) == sizeof(data_t), "non_native_vector_base storage size mismatch");
...
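non_native_vector_base gets the same std:: to ck:: swap in the guard that restricts it to 1-, 2-, 4-, or 8-byte element types. The storage-selection idea (pick a same-size backing data_t for the non-native element, then static_assert the sizes agree) can be sketched standalone as below; storage_selector, its integer mapping, and my_fp8 are illustrative, not CK's nnvb_data_t_selector:

#include <cstdint>
#include <type_traits> // host-side sketch; the CK header uses its own ck::enable_if_t instead

template <typename T, typename Enable = void>
struct storage_selector; // intentionally undefined for unsupported sizes

template <typename T>
struct storage_selector<T, std::enable_if_t<sizeof(T) == 1>> { using type = uint8_t; };
template <typename T>
struct storage_selector<T, std::enable_if_t<sizeof(T) == 2>> { using type = uint16_t; };
template <typename T>
struct storage_selector<T, std::enable_if_t<sizeof(T) == 4>> { using type = uint32_t; };
template <typename T>
struct storage_selector<T, std::enable_if_t<sizeof(T) == 8>> { using type = uint64_t; };

struct my_fp8 { uint8_t bits; }; // a hypothetical 1-byte "non-native" element type

// Mirrors the static_assert in the hunk above: storage must match the element size.
static_assert(sizeof(storage_selector<my_fp8>::type) == sizeof(my_fp8),
              "non_native_vector_base storage size mismatch");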
@@ -1499,7 +1511,7 @@ struct scalar_type<non_native_vector_base<pk_i4_t, N>>
 // non-native vector_type implementation
 template <typename T>
-struct vector_type<T, 1, typename std::enable_if_t<!is_native_type<T>()>>
+struct vector_type<T, 1, typename ck::enable_if_t<!is_native_type<T>()>>
 {
     using d1_t     = T;
     using d1_nnv_t = non_native_vector_base<T, 1>;
...
@@ -1550,7 +1562,7 @@ struct vector_type<T, 1, typename std::enable_if_t<!is_native_type<T>()>>
 };
 
 template <typename T>
-struct vector_type<T, 2, typename std::enable_if_t<!is_native_type<T>()>>
+struct vector_type<T, 2, typename ck::enable_if_t<!is_native_type<T>()>>
 {
     using d1_t     = T;
     using d1_nnv_t = non_native_vector_base<T, 1>;
...
@@ -1613,7 +1625,7 @@ struct vector_type<T, 2, typename std::enable_if_t<!is_native_type<T>()>>
 };
 
 template <typename T>
-struct vector_type<T, 4, typename std::enable_if_t<!is_native_type<T>()>>
+struct vector_type<T, 4, typename ck::enable_if_t<!is_native_type<T>()>>
 {
     using d1_t     = T;
     using d1_nnv_t = non_native_vector_base<T, 1>;
...
@@ -1686,7 +1698,7 @@ struct vector_type<T, 4, typename std::enable_if_t<!is_native_type<T>()>>
 };
 
 template <typename T>
-struct vector_type<T, 8, typename std::enable_if_t<!is_native_type<T>()>>
+struct vector_type<T, 8, typename ck::enable_if_t<!is_native_type<T>()>>
 {
     using d1_t     = T;
     using d1_nnv_t = non_native_vector_base<T, 1>;
...
@@ -1771,7 +1783,7 @@ struct vector_type<T, 8, typename std::enable_if_t<!is_native_type<T>()>>
 };
 
 template <typename T>
-struct vector_type<T, 16, typename std::enable_if_t<!is_native_type<T>()>>
+struct vector_type<T, 16, typename ck::enable_if_t<!is_native_type<T>()>>
 {
     using d1_t     = T;
     using d1_nnv_t = non_native_vector_base<T, 1>;
...
@@ -1866,7 +1878,7 @@ struct vector_type<T, 16, typename std::enable_if_t<!is_native_type<T>()>>
 };
 
 template <typename T>
-struct vector_type<T, 32, typename std::enable_if_t<!is_native_type<T>()>>
+struct vector_type<T, 32, typename ck::enable_if_t<!is_native_type<T>()>>
 {
     using d1_t = T;
     using d2_t = non_native_vector_base<T, 2>;
...
};
template
<
typename
T
>
struct
vector_type
<
T
,
64
,
typename
std
::
enable_if_t
<!
is_native_type
<
T
>
()
>>
struct
vector_type
<
T
,
64
,
typename
ck
::
enable_if_t
<!
is_native_type
<
T
>
()
>>
{
using
d1_t
=
T
;
using
d2_t
=
non_native_vector_base
<
T
,
2
>
;
...
...
@@ -2210,20 +2222,230 @@ using pk_i4x2_t = typename vector_type<pk_i4_t, 2>::type;
 using pk_i4x4_t = typename vector_type<pk_i4_t, 4>::type;
 using pk_i4x8_t = typename vector_type<pk_i4_t, 8>::type;
 
+#ifdef CK_CODE_GEN_RTC
+template <typename T>
+struct NumericLimits;
+
+template <>
+struct NumericLimits<int32_t>
+{
+    __host__ __device__ static constexpr int32_t Lowest() noexcept { return -2147483647 - 1; }
+    __host__ __device__ static constexpr int32_t Min() noexcept { return -2147483647 - 1; }
+    __host__ __device__ static constexpr int32_t Max() noexcept { return 2147483647; }
+    __host__ __device__ static constexpr int32_t Infinity() noexcept { return 0; }
+    __host__ __device__ static constexpr int32_t QuietNaN() { return 0; }
+};
+
+template <>
+struct NumericLimits<int16_t>
+{
+    __host__ __device__ static constexpr int16_t Lowest() noexcept { return -32768; }
+    __host__ __device__ static constexpr int16_t Min() noexcept { return -32768; }
+    __host__ __device__ static constexpr int16_t Max() noexcept { return 32767; }
+    __host__ __device__ static constexpr int16_t Infinity() noexcept { return 0; }
+    __host__ __device__ static constexpr int16_t QuietNaN() { return 0; }
+};
+
+template <>
+struct NumericLimits<int8_t>
+{
+    __host__ __device__ static constexpr int8_t Lowest() noexcept { return -128; }
+    __host__ __device__ static constexpr int8_t Min() noexcept { return -128; }
+    __host__ __device__ static constexpr int8_t Max() noexcept { return 127; }
+    __host__ __device__ static constexpr int8_t Infinity() noexcept { return 0; }
+    __host__ __device__ static constexpr int8_t QuietNaN() { return 0; }
+};
+
+template <>
+struct NumericLimits<uint32_t>
+{
+    __host__ __device__ static constexpr uint32_t Lowest() noexcept { return 0; }
+    __host__ __device__ static constexpr uint32_t Min() noexcept { return 0; }
+    __host__ __device__ static constexpr uint32_t Max() noexcept { return 4294967295U; }
+    __host__ __device__ static constexpr uint32_t Infinity() noexcept { return 0; }
+    __host__ __device__ static constexpr uint32_t QuietNaN() { return 0; }
+};
+
+template <>
+struct NumericLimits<uint16_t>
+{
+    __host__ __device__ static constexpr uint16_t Lowest() noexcept { return 0; }
+    __host__ __device__ static constexpr uint16_t Min() noexcept { return 0; }
+    __host__ __device__ static constexpr uint16_t Max() noexcept { return 65535U; }
+    __host__ __device__ static constexpr uint16_t Infinity() noexcept { return 0; }
+    __host__ __device__ static constexpr uint16_t QuietNaN() { return 0; }
+};
+
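For RTC builds the integer limits are spelled out as literal constants (note that Lowest() for int32_t is written -2147483647 - 1 because the literal 2147483648 would not fit in an int). A hypothetical host-side use of such limits, clamping a 32-bit accumulator into int8_t range the way a quantized GEMM epilogue might:

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main()
{
    constexpr int32_t lo = -128; // NumericLimits<int8_t>::Lowest()
    constexpr int32_t hi = 127;  // NumericLimits<int8_t>::Max()
    int32_t acc = 1000;          // out-of-range accumulator value
    int8_t  out = static_cast<int8_t>(std::min(std::max(acc, lo), hi));
    std::printf("%d\n", out);    // prints 127
    return 0;
}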
+template <>
+struct NumericLimits<float>
+{
+    static constexpr unsigned int binary_min    = 0x00800000;
+    static constexpr unsigned int binary_max    = 0x7F7FFFFF;
+    static constexpr unsigned int binary_lowest = 0xFF7FFFFF;
+    static constexpr unsigned int binary_qnan   = 0xFFC00001;
+    static constexpr unsigned int binary_inf    = 0x7F800000;
+
+    __host__ __device__ static constexpr float Min() { return bit_cast<float>(binary_min); }
+    __host__ __device__ static constexpr float Max() { return bit_cast<float>(binary_max); }
+    __host__ __device__ static constexpr float Lowest() { return bit_cast<float>(binary_lowest); }
+    __host__ __device__ static constexpr float QuietNaN() { return bit_cast<float>(binary_qnan); }
+    __host__ __device__ static constexpr float Infinity() { return bit_cast<float>(binary_inf); }
+};
+
+template <>
+struct NumericLimits<half_t>
+{
+    static constexpr unsigned short binary_min    = 0x0400;
+    static constexpr unsigned short binary_max    = 0x7BFF;
+    static constexpr unsigned short binary_lowest = 0xFBFF;
+    static constexpr unsigned short binary_qnan   = 0x7FFF;
+
+    __host__ __device__ static constexpr half_t Min() { return bit_cast<half_t>(binary_min); }
+    __host__ __device__ static constexpr half_t Max() { return bit_cast<half_t>(binary_max); }
+    __host__ __device__ static constexpr half_t Lowest() { return bit_cast<half_t>(binary_lowest); }
+    __host__ __device__ static constexpr half_t QuietNaN() { return bit_cast<half_t>(binary_qnan); }
+};
+
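Since std::numeric_limits is unavailable under RTC, the float and half limits are stored as raw bit patterns and materialized with bit_cast. A quick C++20 host-side check that the float patterns above decode to the intended IEEE-754 values (std::bit_cast stands in for ck::bit_cast here):

#include <bit>
#include <cassert>
#include <cmath>
#include <limits>

int main()
{
    assert(std::bit_cast<float>(0x00800000u) == std::numeric_limits<float>::min());
    assert(std::bit_cast<float>(0x7F7FFFFFu) == std::numeric_limits<float>::max());
    assert(std::bit_cast<float>(0xFF7FFFFFu) == std::numeric_limits<float>::lowest());
    assert(std::bit_cast<float>(0x7F800000u) == std::numeric_limits<float>::infinity());
    assert(std::isnan(std::bit_cast<float>(0xFFC00001u))); // sign bit set, exp all ones, mantissa != 0
    return 0;
}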
+#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
+template <>
+struct NumericLimits<int4_t>
+{
+    __host__ __device__ static constexpr int4_t Min() { return int4_t(-8); }
+    __host__ __device__ static constexpr int4_t Max() { return int4_t(7); }
+    __host__ __device__ static constexpr int4_t Lowest() { return int4_t(-8); }
+};
+#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
+
+template <>
+struct NumericLimits<f8_fnuz_t>
+{
+    // negative zero nan mode with exp bias = 8
+    static constexpr uint8_t binary_min    = 0x08; // 0b00001000
+    static constexpr uint8_t binary_max    = 0x7F; // 0b01111111
+    static constexpr uint8_t binary_lowest = 0xFF; // 0b11111111
+    static constexpr uint8_t binary_qnan   = 0x80; // 0b10000000
+    // ieee mode with exp bias = 7
+    // static constexpr uint8_t binary_min    = 0x08; // 0b00001000
+    // static constexpr uint8_t binary_max    = 0x77; // 0b01110111
+    // static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111
+    // static constexpr uint8_t binary_qnan   = 0x79; // any sign, exp=1111, mant!=0
+
+    __host__ __device__ static constexpr f8_fnuz_t Min() { return f8_fnuz_t(binary_min); }
+    __host__ __device__ static constexpr f8_fnuz_t Max() { return f8_fnuz_t(binary_max); }
+    __host__ __device__ static constexpr f8_fnuz_t Lowest() { return f8_fnuz_t(binary_lowest); }
+    __host__ __device__ static constexpr f8_fnuz_t QuietNaN() { return f8_fnuz_t(binary_qnan); }
+};
+
+template <>
+struct NumericLimits<bf8_fnuz_t>
+{
+    // negative zero nan mode with exp bias = 16
+    static constexpr uint8_t binary_min    = 0x04; // 0b00000100
+    static constexpr uint8_t binary_max    = 0x7F; // 0b01111111
+    static constexpr uint8_t binary_lowest = 0xFF; // 0b11111111
+    static constexpr uint8_t binary_qnan   = 0x80; // 0b10000000
+    // ieee mode with exp bias = 15
+    // static constexpr uint8_t binary_min    = 0x04; // 0b00000100
+    // static constexpr uint8_t binary_max    = 0x7B; // 0b01111011
+    // static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011
+    // static constexpr uint8_t binary_qnan   = 0x79; // any sign, exp=1111, mant!=0
+
+    __host__ __device__ static constexpr bf8_fnuz_t Min() { return bf8_fnuz_t(binary_min); }
+    __host__ __device__ static constexpr bf8_fnuz_t Max() { return bf8_fnuz_t(binary_max); }
+    __host__ __device__ static constexpr bf8_fnuz_t Lowest() { return bf8_fnuz_t(binary_lowest); }
+    __host__ __device__ static constexpr bf8_fnuz_t QuietNaN() { return bf8_fnuz_t(binary_qnan); }
+};
+
+template <>
+struct NumericLimits<f8_ocp_t>
+{
+    static constexpr uint8_t binary_min    = 0x08; // 0b00001000 = 2^-6
+    static constexpr uint8_t binary_max    = 0x7E; // 0b01111110 = 448
+    static constexpr uint8_t binary_lowest = 0xFE; // 0b11111110 = -448
+    static constexpr uint8_t binary_qnan   = 0x7F; // 0b01111111
+
+    __host__ __device__ static constexpr f8_ocp_t Min() { return bit_cast<f8_ocp_t>(binary_min); }
+    __host__ __device__ static constexpr f8_ocp_t Max() { return bit_cast<f8_ocp_t>(binary_max); }
+    __host__ __device__ static constexpr f8_ocp_t Lowest() { return bit_cast<f8_ocp_t>(binary_lowest); }
+    __host__ __device__ static constexpr f8_ocp_t QuietNaN() { return bit_cast<f8_ocp_t>(binary_qnan); }
+};
+
+template <>
+struct NumericLimits<bf8_ocp_t>
+{
+    static constexpr uint8_t binary_min    = 0x04; // 0b00000100 = 2^-14
+    static constexpr uint8_t binary_max    = 0x7B; // 0b01111011 = 57344
+    static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011 = -57344
+    static constexpr uint8_t binary_qnan   = 0x7D; // 0b01111101
+
+    __host__ __device__ static constexpr bf8_ocp_t Min() { return bit_cast<bf8_ocp_t>(binary_min); }
+    __host__ __device__ static constexpr bf8_ocp_t Max() { return bit_cast<bf8_ocp_t>(binary_max); }
+    __host__ __device__ static constexpr bf8_ocp_t Lowest() { return bit_cast<bf8_ocp_t>(binary_lowest); }
+    __host__ __device__ static constexpr bf8_ocp_t QuietNaN() { return bit_cast<bf8_ocp_t>(binary_qnan); }
+};
+
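The FP8 limits encode OCP E4M3 (f8_ocp_t) and E5M2 (bf8_ocp_t) bit patterns; the decimal values in the comments can be re-derived by decoding the normal-number fields by hand. A small decoding sketch, with field widths and exponent biases (7 and 15) taken from the comments above; this is illustrative, not CK code:

#include <cmath>
#include <cstdint>
#include <cstdio>

// Decode a normal E4M3 (f8_ocp_t-style) byte: 1 sign, 4 exponent, 3 mantissa bits.
static double decode_e4m3(uint8_t b)
{
    int sign = (b >> 7) ? -1 : 1;
    int exp  = (b >> 3) & 0xF;
    int mant = b & 0x7;
    return sign * (1.0 + mant / 8.0) * std::exp2(exp - 7);
}

// Decode a normal E5M2 (bf8_ocp_t-style) byte: 1 sign, 5 exponent, 2 mantissa bits.
static double decode_e5m2(uint8_t b)
{
    int sign = (b >> 7) ? -1 : 1;
    int exp  = (b >> 2) & 0x1F;
    int mant = b & 0x3;
    return sign * (1.0 + mant / 4.0) * std::exp2(exp - 15);
}

int main()
{
    std::printf("%g\n", decode_e4m3(0x7E)); // 448, matching f8_ocp_t binary_max
    std::printf("%g\n", decode_e5m2(0x7B)); // 57344, matching bf8_ocp_t binary_max
    return 0;
}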
+#else
 template <typename T>
 struct NumericLimits
 {
     __host__ __device__ static constexpr T Min() { return std::numeric_limits<T>::min(); }
     __host__ __device__ static constexpr T Max() { return std::numeric_limits<T>::max(); }
     __host__ __device__ static constexpr T Lowest() { return std::numeric_limits<T>::lowest(); }
     __host__ __device__ static constexpr T QuietNaN() { return std::numeric_limits<T>::quiet_NaN(); }
     __host__ __device__ static constexpr T Infinity() { return std::numeric_limits<T>::infinity(); }
 };
...
@@ -2347,6 +2569,7 @@ struct NumericLimits<bf8_ocp_t>
         return bit_cast<bf8_ocp_t>(binary_qnan);
     }
 };
+#endif
 
 template <typename T>
 struct NumericUtils
...
include/ck/utility/debug.hpp
View file @ d480a5a6

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 
 #ifndef UTILITY_DEBUG_HPP
 #define UTILITY_DEBUG_HPP
 
 #include "type.hpp"
 
 namespace ck {
 namespace debug {
...
include/ck/utility/enable_if.hpp
View file @ d480a5a6

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

namespace ck {

#ifndef CK_CODE_GEN_RTC
template <bool B, typename T = void>
using enable_if = std::enable_if<B, T>;

template <bool B, typename T = void>
using enable_if_t = typename std::enable_if<B, T>::type;
#else
template <bool B, class T = void>
struct enable_if
{
};

template <class T>
struct enable_if<true, T>
{
    using type = T;
};

template <bool B, class T = void>
using enable_if_t = typename enable_if<B, T>::type;
#endif

} // namespace ck
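With this fallback in place, ck::enable_if_t behaves the same whether it aliases std::enable_if or the local RTC implementation. A standalone sketch of the overload toggling it enables (the fallback definition is copied inline so the snippet compiles on its own):

#include <cstdio>

template <bool B, class T = void> struct enable_if {};
template <class T> struct enable_if<true, T> { using type = T; };
template <bool B, class T = void> using enable_if_t = typename enable_if<B, T>::type;

// Exactly one overload survives substitution for a given T.
template <typename T, enable_if_t<sizeof(T) == 4, bool> = false>
const char* width(T) { return "4-byte"; }

template <typename T, enable_if_t<sizeof(T) == 8, bool> = false>
const char* width(T) { return "8-byte"; }

int main()
{
    // Assumes the common case of 4-byte int and 8-byte double.
    std::printf("%s %s\n", width(1), width(1.0)); // prints "4-byte 8-byte"
    return 0;
}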