Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
0c823497
Commit
0c823497
authored
Nov 10, 2023
by
muozturk
Browse files
merge
parents
334cfe1c
68f2b5e7
Changes
415
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1418 additions
and
989 deletions
+1418
-989
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
...or_operation/gpu/element/unary_element_wise_operation.hpp
+181
-14
include/ck/tensor_operation/gpu/grid/gridwise_elementwise_3d.hpp
.../ck/tensor_operation/gpu/grid/gridwise_elementwise_3d.hpp
+264
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp
...tion/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp
+3
-2
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp
...ation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp
+27
-14
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
+2
-1
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
...tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
+109
-55
include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp
...k/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp
+32
-11
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_naive_variance.hpp
...d/normalization/gridwise_normalization_naive_variance.hpp
+112
-5
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_selector.hpp
...pu/grid/normalization/gridwise_normalization_selector.hpp
+50
-16
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_2nd.hpp
.../grid/normalization/gridwise_normalization_splitk_2nd.hpp
+85
-4
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_welford_variance.hpp
...normalization/gridwise_normalization_welford_variance.hpp
+110
-7
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp
...tion/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp
+54
-23
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp
...tion/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp
+55
-89
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
+4
-28
include/ck/utility/amd_buffer_addressing.hpp
include/ck/utility/amd_buffer_addressing.hpp
+210
-628
include/ck/utility/amd_xdlops.hpp
include/ck/utility/amd_xdlops.hpp
+2
-13
include/ck/utility/data_type.hpp
include/ck/utility/data_type.hpp
+9
-26
include/ck/utility/f8_utils.hpp
include/ck/utility/f8_utils.hpp
+98
-53
include/ck/utility/inner_product.hpp
include/ck/utility/inner_product.hpp
+2
-0
include/ck/utility/is_detected.hpp
include/ck/utility/is_detected.hpp
+9
-0
No files found.
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
View file @
0c823497
...
...
@@ -16,6 +16,57 @@ namespace element_wise {
extern
"C"
__device__
float
__ocml_native_recip_f32
(
float
);
#endif
struct
PassThroughPack2
{
template
<
typename
Y
,
typename
X
>
__host__
__device__
void
operator
()(
Y
&
y
,
const
X
&
x
)
const
;
__host__
__device__
constexpr
void
operator
()(
ck
::
f8x2_t
&
y
,
const
ck
::
half2_t
&
x
)
const
{
// fake conversion
uint16_t
t
=
ck
::
bit_cast
<
uint32_t
>
(
x
);
y
=
ck
::
bit_cast
<
ck
::
f8x2_t
>
(
t
);
}
__host__
__device__
constexpr
void
operator
()(
ck
::
half2_t
&
y
,
const
ck
::
f8x2_t
&
x
)
const
{
auto
t
=
type_convert
<
float2_t
>
(
x
);
y
=
type_convert
<
half2_t
>
(
t
);
}
__host__
__device__
constexpr
void
operator
()(
ck
::
half2_t
&
y
,
const
ck
::
half2_t
&
x
)
const
{
y
=
x
;
}
__host__
__device__
constexpr
void
operator
()(
ck
::
f8x2_t
&
y
,
const
ck
::
f8x2_t
&
x
)
const
{
y
=
x
;
}
__host__
__device__
constexpr
void
operator
()(
ck
::
float2_t
&
y
,
const
ck
::
float2_t
&
x
)
const
{
y
=
x
;
}
__host__
__device__
constexpr
void
operator
()(
ck
::
int8x2_t
&
y
,
const
ck
::
int8x2_t
&
x
)
const
{
y
=
x
;
}
__host__
__device__
constexpr
void
operator
()(
ck
::
bhalf2_t
&
y
,
const
ck
::
bhalf2_t
&
x
)
const
{
y
=
x
;
}
__host__
__device__
constexpr
void
operator
()(
ck
::
double2_t
&
y
,
const
ck
::
double2_t
&
x
)
const
{
y
=
x
;
}
constexpr
const
static
bool
is_pack2_invocable
=
true
;
};
struct
PassThrough
{
template
<
typename
Y
,
typename
X
>
...
...
@@ -33,6 +84,12 @@ struct PassThrough
y
=
type_convert
<
float
>
(
x
);
}
template
<
>
__host__
__device__
void
operator
()
<
double
,
float
>
(
double
&
y
,
const
float
&
x
)
const
{
y
=
type_convert
<
double
>
(
x
);
}
template
<
>
__host__
__device__
void
operator
()
<
float
,
float
>
(
float
&
y
,
const
float
&
x
)
const
{
...
...
@@ -69,6 +126,12 @@ struct PassThrough
y
=
type_convert
<
bhalf_t
>
(
x
);
}
template
<
>
__host__
__device__
void
operator
()
<
float
,
bhalf_t
>
(
float
&
y
,
const
bhalf_t
&
x
)
const
{
y
=
type_convert
<
float
>
(
x
);
}
template
<
>
__host__
__device__
void
operator
()
<
bhalf_t
,
half_t
>
(
bhalf_t
&
y
,
const
half_t
&
x
)
const
{
...
...
@@ -113,7 +176,6 @@ struct PassThrough
}
#endif
#if defined CK_ENABLE_FP8
template
<
>
__host__
__device__
void
operator
()
<
f8_t
,
f8_t
>
(
f8_t
&
y
,
const
f8_t
&
x
)
const
{
...
...
@@ -143,9 +205,7 @@ struct PassThrough
{
y
=
type_convert
<
f8_t
>
(
x
);
}
#endif
#if defined CK_ENABLE_BF8
template
<
>
__host__
__device__
void
operator
()
<
bf8_t
,
bf8_t
>
(
bf8_t
&
y
,
const
bf8_t
&
x
)
const
{
...
...
@@ -173,10 +233,8 @@ struct PassThrough
template
<
>
__host__
__device__
void
operator
()
<
bf8_t
,
half_t
>
(
bf8_t
&
y
,
const
half_t
&
x
)
const
{
// to-do: fix half_t to bf8_t convert
y
=
ck
::
type_convert
<
bf8_t
>
(
ck
::
type_convert
<
float
>
(
x
));
y
=
ck
::
type_convert
<
bf8_t
>
(
x
);
}
#endif
};
struct
UnaryConvert
...
...
@@ -205,7 +263,6 @@ struct ConvertBF16RTN
}
};
#if defined CK_ENABLE_FP8
struct
ConvertF8SR
{
// convert to fp8 using stochastic rounding (SR)
...
...
@@ -213,7 +270,8 @@ struct ConvertF8SR
__host__
__device__
void
operator
()(
Y
&
y
,
const
X
&
x
)
const
{
// check Y datatype
static_assert
(
is_same
<
Y
,
f8_t
>::
value
,
"Data type is not supported by this operation!"
);
static_assert
(
is_same
<
Y
,
f8_t
>::
value
||
is_same
<
Y
,
bf8_t
>::
value
,
"Data type is not supported by this operation!"
);
// check X datatype
static_assert
(
is_same
<
X
,
float
>::
value
||
is_same
<
X
,
half_t
>::
value
,
...
...
@@ -222,7 +280,6 @@ struct ConvertF8SR
y
=
f8_convert_sr
<
Y
>
(
x
);
}
};
#endif
struct
Scale
{
...
...
@@ -231,6 +288,20 @@ struct Scale
template
<
typename
Y
,
typename
X
>
__host__
__device__
void
operator
()(
Y
&
y
,
const
X
&
x
)
const
;
template
<
>
__host__
__device__
void
operator
()
<
half_t
,
half_t
>
(
half_t
&
y
,
const
half_t
&
x
)
const
{
y
=
ck
::
type_convert
<
half_t
>
(
scale_
)
*
x
;
};
template
<
>
__host__
__device__
void
operator
()
<
bhalf_t
,
bhalf_t
>
(
bhalf_t
&
y
,
const
bhalf_t
&
x
)
const
{
const
float
x_tmp
=
ck
::
type_convert
<
float
>
(
x
);
const
float
y_tmp
=
scale_
*
x_tmp
;
y
=
ck
::
type_convert
<
bhalf_t
>
(
y_tmp
);
};
template
<
>
__host__
__device__
void
operator
()
<
float
,
float
>
(
float
&
y
,
const
float
&
x
)
const
{
...
...
@@ -449,10 +520,11 @@ struct Sigmoid
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
ck
::
half_t
>::
value
,
is_same
<
T
,
ck
::
half_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
,
"Data type is not supported by this operation!"
);
y
=
1
/
(
ck
::
type_convert
<
T
>
(
1
)
+
exp
(
-
x
));
constexpr
T
one
=
type_convert
<
T
>
(
1
);
y
=
one
/
(
one
+
ck
::
math
::
exp
(
-
x
));
};
};
...
...
@@ -462,7 +534,8 @@ struct TanH
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
ck
::
half_t
>::
value
,
is_same
<
T
,
ck
::
half_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
,
"Data type is not supported by this operation!"
);
y
=
ck
::
math
::
tanh
(
x
);
...
...
@@ -488,7 +561,101 @@ struct Swish
y
=
type_convert
<
Y
>
(
x
/
(
1.
f
+
ck
::
math
::
exp
(
bx
)));
};
float
beta_
=
1.0
f
;
const
float
beta_
;
};
struct
SoftRelu
{
SoftRelu
(
float
alpha
=
1.
f
)
:
alpha_
(
alpha
){};
template
<
typename
T
>
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
"Data type is not supported by this operation!"
);
T
casted_alpha
=
type_convert
<
T
>
(
alpha_
);
constexpr
T
one
=
type_convert
<
T
>
(
1
);
y
=
ck
::
math
::
log
(
one
+
ck
::
math
::
exp
(
x
*
casted_alpha
))
/
casted_alpha
;
}
const
float
alpha_
;
};
struct
Power
{
Power
(
float
alpha
=
0.
f
,
float
beta
=
1.
f
,
float
gamma
=
2.
f
)
:
alpha_
(
alpha
),
beta_
(
beta
),
gamma_
(
gamma
){};
template
<
typename
T
>
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
"Data type is not supported by this operation!"
);
T
casted_alpha
=
type_convert
<
T
>
(
alpha_
);
T
casted_beta
=
type_convert
<
T
>
(
beta_
);
T
casted_gamma
=
type_convert
<
T
>
(
gamma_
);
T
shifted_scaled_x
=
casted_alpha
+
casted_beta
*
x
;
y
=
ck
::
math
::
pow
(
shifted_scaled_x
,
casted_gamma
);
}
const
float
alpha_
;
const
float
beta_
;
const
float
gamma_
;
};
struct
ClippedRelu
{
ClippedRelu
(
float
alpha
=
0.
f
,
float
beta
=
1.
f
)
:
alpha_
(
alpha
),
beta_
(
beta
){};
template
<
typename
T
>
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
"Data type is not supported by this operation!"
);
T
casted_alpha
=
type_convert
<
T
>
(
alpha_
);
T
casted_beta
=
type_convert
<
T
>
(
beta_
);
y
=
ck
::
math
::
min
(
casted_beta
,
ck
::
math
::
max
(
casted_alpha
,
x
));
}
const
float
alpha_
;
const
float
beta_
;
};
struct
LeakyRelu
{
LeakyRelu
(
float
alpha
=
0.01
f
)
:
alpha_
(
alpha
){};
template
<
typename
T
>
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
"Data type is not supported by this operation!"
);
T
casted_alpha
=
type_convert
<
T
>
(
alpha_
);
y
=
x
>=
0
?
x
:
x
*
casted_alpha
;
}
const
float
alpha_
;
};
struct
Elu
{
Elu
(
float
alpha
=
1.
f
)
:
alpha_
(
alpha
){};
template
<
typename
T
>
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
"Data type is not supported by this operation!"
);
T
casted_alpha
=
type_convert
<
T
>
(
alpha_
);
y
=
x
>
0
?
x
:
casted_alpha
*
ck
::
math
::
expm1
(
x
);
}
const
float
alpha_
;
};
}
// namespace element_wise
...
...
include/ck/tensor_operation/gpu/grid/gridwise_elementwise_3d.hpp
0 → 100644
View file @
0c823497
// SPDX-License-Identifier: MIT
// // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
//
#pragma once
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace
ck
{
template
<
typename
GridwiseElementwise3dFunctor
,
typename
InGrid3dDescTuple
,
typename
OutGrid3dDescTuple
,
typename
InDataTypePointerTuple
,
typename
OutDataTypePointerTuple
,
typename
ElementwiseOperation
>
__global__
void
kernel_elementwise_3d
(
const
InGrid3dDescTuple
in_grid_3d_desc_tuple
,
const
OutGrid3dDescTuple
out_grid_3d_desc_tuple
,
const
InDataTypePointerTuple
p_in_global_tuple
,
const
OutDataTypePointerTuple
p_out_global_tuple
,
const
ElementwiseOperation
elementwise_op
,
const
index_t
num_threads_m
,
const
index_t
num_threads_n
,
const
index_t
num_threads_k
)
{
GridwiseElementwise3dFunctor
::
Run
(
in_grid_3d_desc_tuple
,
out_grid_3d_desc_tuple
,
p_in_global_tuple
,
p_out_global_tuple
,
elementwise_op
,
num_threads_m
,
num_threads_n
,
num_threads_k
);
}
template
<
typename
InGrid3dDescTuple
,
typename
OutGrid3dDescTuple
,
typename
InDataTypePointerTuple
,
typename
OutDataTypePointerTuple
,
typename
ElementwiseOperation
,
index_t
MPerThread
,
index_t
NPerThread
,
index_t
KPerThread
,
typename
InScalarPerVectorSeq
,
typename
OutScalarPerVectorSeq
>
struct
GridwiseElementwise_3D
{
static
constexpr
index_t
NumInput
=
InDataTypePointerTuple
::
Size
();
static
constexpr
index_t
NumOutput
=
OutDataTypePointerTuple
::
Size
();
static_assert
(
NumInput
==
InScalarPerVectorSeq
::
Size
()
&&
NumOutput
==
OutScalarPerVectorSeq
::
Size
()
&&
NumInput
==
InGrid3dDescTuple
::
Size
()
&&
NumOutput
==
OutGrid3dDescTuple
::
Size
(),
"Tuple size is inconsistent with the number of in/out!"
);
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
thread_buffer_desc_mnk
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MPerThread
>
{},
Number
<
NPerThread
>
{},
Number
<
KPerThread
>
{}));
using
PassThroughOp
=
tensor_operation
::
element_wise
::
PassThrough
;
__device__
static
void
Run
(
const
InGrid3dDescTuple
in_grid_3d_desc_tuple
,
const
OutGrid3dDescTuple
out_grid_3d_desc_tuple
,
const
InDataTypePointerTuple
p_in_global_tuple
,
const
OutDataTypePointerTuple
p_out_global_tuple
,
const
ElementwiseOperation
elementwise_op
,
const
index_t
num_threads_m
,
const
index_t
num_threads_n
,
const
index_t
num_threads_k
)
{
auto
in_thread_buf_tuple
=
generate_tuple
(
[
&
](
auto
I
)
{
using
DataTypePointer
=
remove_cvref_t
<
decltype
(
InDataTypePointerTuple
{}[
I
])
>
;
using
DataType
=
remove_cv_t
<
remove_pointer_t
<
DataTypePointer
>>
;
return
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
DataType
,
MPerThread
*
NPerThread
*
KPerThread
,
true
>
{};
},
Number
<
NumInput
>
{});
auto
out_thread_buf_tuple
=
generate_tuple
(
[
&
](
auto
I
)
{
using
DataTypePointer
=
remove_cvref_t
<
decltype
(
OutDataTypePointerTuple
{}[
I
])
>
;
using
DataType
=
remove_pointer_t
<
DataTypePointer
>
;
return
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
DataType
,
MPerThread
*
NPerThread
*
KPerThread
,
true
>
{};
},
Number
<
NumOutput
>
{});
auto
in_global_buf_tuple
=
generate_tuple
(
[
&
](
auto
I
)
{
return
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_in_global_tuple
[
I
],
in_grid_3d_desc_tuple
[
I
].
GetElementSpaceSize
());
},
Number
<
NumInput
>
{});
auto
out_global_buf_tuple
=
generate_tuple
(
[
&
](
auto
I
)
{
return
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_out_global_tuple
[
I
],
out_grid_3d_desc_tuple
[
I
].
GetElementSpaceSize
());
},
Number
<
NumOutput
>
{});
const
auto
M
=
in_grid_3d_desc_tuple
[
I0
].
GetLength
(
I0
);
const
auto
N
=
in_grid_3d_desc_tuple
[
I0
].
GetLength
(
I1
);
const
auto
K
=
in_grid_3d_desc_tuple
[
I0
].
GetLength
(
I2
);
const
index_t
loop_step_m
=
num_threads_m
*
MPerThread
;
const
index_t
loop_step_n
=
num_threads_n
*
NPerThread
;
const
index_t
loop_step_k
=
num_threads_k
*
KPerThread
;
const
index_t
thread_1d_id
=
get_thread_global_1d_id
();
const
index_t
tid_m
=
thread_1d_id
/
(
num_threads_n
*
num_threads_k
);
const
index_t
tid_nk
=
thread_1d_id
%
(
num_threads_n
*
num_threads_k
);
const
index_t
tid_n
=
tid_nk
/
num_threads_k
;
const
index_t
tid_k
=
tid_nk
%
num_threads_k
;
const
auto
thread_global_offset
=
make_multi_index
(
tid_m
*
MPerThread
,
tid_n
*
NPerThread
,
tid_k
*
KPerThread
);
auto
in_global_load_tuple
=
generate_tuple
(
[
&
](
auto
I
)
{
using
DataTypePointer
=
remove_cvref_t
<
decltype
(
InDataTypePointerTuple
{}[
I
])
>
;
using
DataType
=
remove_cv_t
<
remove_pointer_t
<
DataTypePointer
>>
;
return
ThreadwiseTensorSliceTransfer_v2
<
DataType
,
DataType
,
decltype
(
in_grid_3d_desc_tuple
[
I
]),
decltype
(
thread_buffer_desc_mnk
),
Sequence
<
MPerThread
,
NPerThread
,
KPerThread
>
,
// SliceLengths
Sequence
<
0
,
1
,
2
>
,
// DimAccessOrder
01
,
// SrcVectorDim
InScalarPerVectorSeq
::
At
(
I
),
// InScalarPerVectorSeq::At(I), //
// ScalarPerVector
1
,
// SrcScalarStrideInVector
true
>
{
in_grid_3d_desc_tuple
[
I
],
thread_global_offset
};
},
Number
<
NumInput
>
{});
auto
out_global_store_tuple
=
generate_tuple
(
[
&
](
auto
I
)
{
using
DataTypePointer
=
remove_cvref_t
<
decltype
(
OutDataTypePointerTuple
{}[
I
])
>
;
using
DataType
=
remove_pointer_t
<
DataTypePointer
>
;
return
ThreadwiseTensorSliceTransfer_v1r3
<
DataType
,
DataType
,
decltype
(
thread_buffer_desc_mnk
),
decltype
(
out_grid_3d_desc_tuple
[
I
]),
PassThroughOp
,
Sequence
<
MPerThread
,
NPerThread
,
KPerThread
>
,
// SliceLengths
Sequence
<
0
,
1
,
2
>
,
// DimAccessOrder
2
,
// SrcVectorDim
OutScalarPerVectorSeq
::
At
(
I
),
// OutScalarPerVectorSeq::At(I),
InMemoryDataOperationEnum
::
Set
,
1
,
true
>
(
out_grid_3d_desc_tuple
[
I
],
thread_global_offset
,
PassThroughOp
{});
},
Number
<
NumOutput
>
{});
index_t
num_iter_m
=
M
/
(
loop_step_m
);
do
{
index_t
num_iter_n
=
N
/
(
loop_step_n
);
do
{
index_t
num_iter_k
=
K
/
(
loop_step_k
);
do
{
static_for
<
0
,
NumInput
,
1
>
{}([
&
](
auto
I
)
{
in_global_load_tuple
(
I
).
Run
(
in_grid_3d_desc_tuple
[
I
],
in_global_buf_tuple
[
I
],
thread_buffer_desc_mnk
,
make_tuple
(
I0
,
I0
,
I0
),
in_thread_buf_tuple
(
I
));
in_global_load_tuple
(
I
).
MoveSrcSliceWindow
(
in_grid_3d_desc_tuple
[
I
],
make_multi_index
(
0
,
0
,
loop_step_k
));
});
static_for
<
0
,
MPerThread
,
1
>
{}([
&
](
auto
iM
)
{
static_for
<
0
,
NPerThread
,
1
>
{}([
&
](
auto
iN
)
{
static_for
<
0
,
KPerThread
,
1
>
{}([
&
](
auto
iK
)
{
constexpr
auto
offset
=
thread_buffer_desc_mnk
.
CalculateOffset
(
make_tuple
(
iM
,
iN
,
iK
));
// get reference to in data
const
auto
in_data_refs
=
generate_tie
(
// return type should be lvalue
[
&
](
auto
I
)
->
const
auto
&
{
return
in_thread_buf_tuple
(
I
)(
Number
<
offset
>
{});
},
Number
<
NumInput
>
{});
// get referenec to dst data
auto
out_data_refs
=
generate_tie
(
// return type should be lvalue
[
&
](
auto
I
)
->
auto
&
{
return
out_thread_buf_tuple
(
I
)(
Number
<
offset
>
{});
},
Number
<
NumOutput
>
{});
unpack2
(
elementwise_op
,
out_data_refs
,
in_data_refs
);
});
});
});
static_for
<
0
,
NumOutput
,
1
>
{}([
&
](
auto
I
)
{
out_global_store_tuple
(
I
).
Run
(
thread_buffer_desc_mnk
,
make_tuple
(
I0
,
I0
,
I0
),
out_thread_buf_tuple
[
I
],
out_grid_3d_desc_tuple
[
I
],
out_global_buf_tuple
(
I
));
out_global_store_tuple
(
I
).
MoveDstSliceWindow
(
out_grid_3d_desc_tuple
[
I
],
make_multi_index
(
0
,
0
,
loop_step_k
));
});
}
while
(
--
num_iter_k
);
static_for
<
0
,
NumInput
,
1
>
{}([
&
](
auto
I
)
{
in_global_load_tuple
(
I
).
MoveSrcSliceWindow
(
in_grid_3d_desc_tuple
[
I
],
make_multi_index
(
0
,
loop_step_n
,
-
(
K
/
loop_step_k
)
*
loop_step_k
));
});
static_for
<
0
,
NumOutput
,
1
>
{}([
&
](
auto
I
)
{
out_global_store_tuple
(
I
).
MoveDstSliceWindow
(
out_grid_3d_desc_tuple
[
I
],
make_multi_index
(
0
,
loop_step_n
,
-
(
K
/
loop_step_k
)
*
loop_step_k
));
});
}
while
(
--
num_iter_n
);
static_for
<
0
,
NumInput
,
1
>
{}([
&
](
auto
I
)
{
in_global_load_tuple
(
I
).
MoveSrcSliceWindow
(
in_grid_3d_desc_tuple
[
I
],
make_multi_index
(
loop_step_m
,
-
(
N
/
loop_step_n
)
*
loop_step_n
,
-
(
K
/
loop_step_k
)
*
loop_step_k
));
});
static_for
<
0
,
NumOutput
,
1
>
{}([
&
](
auto
I
)
{
out_global_store_tuple
(
I
).
MoveDstSliceWindow
(
out_grid_3d_desc_tuple
[
I
],
make_multi_index
(
loop_step_m
,
-
(
N
/
loop_step_n
)
*
loop_step_n
,
-
(
K
/
loop_step_k
)
*
loop_step_k
));
});
}
while
(
--
num_iter_m
);
}
};
}
// namespace ck
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp
View file @
0c823497
...
...
@@ -428,7 +428,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
[
&
](
auto
i
)
{
using
ALayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
AsLayout
>>
;
return
MakeAGridDescriptor_M_
N
<
ALayout
,
GemmSpec
>
(
MRaws
[
i
],
KRaws
[
i
],
AsStride
[
i
]);
return
MakeAGridDescriptor_M_
K
<
ALayout
,
GemmSpec
>
(
MRaws
[
i
],
KRaws
[
i
],
AsStride
[
i
]);
},
Number
<
NumATensor
>
{});
}
...
...
@@ -656,7 +656,8 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
auto
blockwise_gemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector
<
BlockSize
,
ComputeDataType
,
ComputeDataType
,
// ComputeDataType for A
ComputeDataType
,
// ComputeDataType for B
AccDataType
,
decltype
(
a_block_desc_ak0_m_ak1
),
decltype
(
b_block_desc_bk0_n_bk1
),
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp
View file @
0c823497
...
...
@@ -36,7 +36,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
kernel_grouped_conv_
fwd_
multiple_d_wmma_cshuffle
(
kernel_grouped_conv_multiple_d_wmma_cshuffle
(
const
ADataType
*
__restrict__
p_a_grid
,
const
BDataType
*
__restrict__
p_b_grid
,
DsPointer
p_ds_grid
,
...
...
@@ -452,11 +452,11 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
// CheckValidity for kernels without multi D
template
<
typename
Block2CTileMap
>
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
AGridDesc_K0_M_K1
&
a_grid_desc_k0_m_k1
,
const
BGridDesc_K0_N_K1
&
b_grid_desc_k0_n_k1
,
const
DsGridDesc_M_N
&
ds_grid_desc_m_n
,
const
EGridDesc_M_N
&
e_grid_desc_m_n
,
const
Block2CTileMap
&
block_2_ctile_map
)
{
...
...
@@ -471,18 +471,6 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
const
auto
N
=
b_grid_desc_k0_n_k1
.
GetLength
(
I1
);
const
auto
K0
=
a_grid_desc_k0_m_k1
.
GetLength
(
I0
);
bool
valid
=
true
;
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
valid
=
valid
&&
(
M
==
ds_grid_desc_m_n
[
i
].
GetLength
(
I0
)
&&
N
==
ds_grid_desc_m_n
[
i
].
GetLength
(
I1
));
});
if
(
!
valid
)
{
return
false
;
}
if
(
!
(
M
==
e_grid_desc_m_n
.
GetLength
(
I0
)
&&
N
==
e_grid_desc_m_n
.
GetLength
(
I1
)
&&
K0
==
b_grid_desc_k0_n_k1
.
GetLength
(
I0
)
&&
K1
==
a_grid_desc_k0_m_k1
.
GetLength
(
I2
)
&&
K1
==
b_grid_desc_k0_n_k1
.
GetLength
(
I2
)))
...
...
@@ -517,6 +505,31 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
return
true
;
}
template
<
typename
Block2CTileMap
>
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
AGridDesc_K0_M_K1
&
a_grid_desc_k0_m_k1
,
const
BGridDesc_K0_N_K1
&
b_grid_desc_k0_n_k1
,
const
DsGridDesc_M_N
&
ds_grid_desc_m_n
,
const
EGridDesc_M_N
&
e_grid_desc_m_n
,
const
Block2CTileMap
&
block_2_ctile_map
)
{
const
auto
M
=
a_grid_desc_k0_m_k1
.
GetLength
(
I1
);
const
auto
N
=
b_grid_desc_k0_n_k1
.
GetLength
(
I1
);
bool
valid
=
true
;
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
valid
=
valid
&&
(
M
==
ds_grid_desc_m_n
[
i
].
GetLength
(
I0
)
&&
N
==
ds_grid_desc_m_n
[
i
].
GetLength
(
I1
));
});
if
(
!
valid
)
{
return
false
;
}
return
CheckValidity
(
a_grid_desc_k0_m_k1
,
b_grid_desc_k0_n_k1
,
e_grid_desc_m_n
,
block_2_ctile_map
);
}
__host__
__device__
static
constexpr
bool
CalculateHasMainKBlockLoop
(
index_t
K
)
{
const
index_t
num_loop
=
K
/
(
K0PerBlock
*
K1
);
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
View file @
0c823497
...
...
@@ -945,7 +945,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext
}
}();
if
constexpr
(
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
)
if
constexpr
(
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
)
{
return
transform_tensor_descriptor
(
c_grid_desc_m_n
,
make_tuple
(
make_right_pad_transform
(
M
,
MPad
-
M
),
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
View file @
0c823497
...
...
@@ -22,13 +22,19 @@ namespace ck {
template
<
typename
GridwiseGemm
,
bool
HasMainKBlockLoop
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
typename
Block2CTileMap
>
typename
Block2CTileMap
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
kernel_gemm_xdlops_v2r4r2_simplified
(
typename
GridwiseGemm
::
Argument
karg
,
const
Block2CTileMap
&
b2c_map
)
const
Block2CTileMap
&
b2c_map
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
CElementwiseOperation
c_element_op
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
...
...
@@ -37,10 +43,13 @@ __global__ void
__shared__
uint8_t
p_shared
[
shared_size
];
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
,
CGlobalMemoryDataOperation
>(
karg
,
static_cast
<
void
*>
(
p_shared
),
b2c_map
);
karg
,
static_cast
<
void
*>
(
p_shared
),
b2c_map
,
a_element_op
,
b_element_op
,
c_element_op
);
#else
ignore
=
karg
;
ignore
=
b2c_map
;
ignore
=
a_element_op
;
ignore
=
b_element_op
;
ignore
=
c_element_op
;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
...
...
@@ -127,7 +136,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
index_t
MPadded
;
index_t
NPadded
;
index_t
KPadded
;
index_t
K0
;
index_t
K0
Padded
;
index_t
k_batch
;
Argument
(
const
FloatA
*
p_a_grid_
,
...
...
@@ -142,7 +151,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
index_t
MPadded_
,
index_t
NPadded_
,
index_t
KPadded_
,
index_t
K0_
,
index_t
K0
Padded
_
,
index_t
k_batch_
)
:
p_a_grid
(
p_a_grid_
),
p_b_grid
(
p_b_grid_
),
...
...
@@ -156,7 +165,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
MPadded
(
MPadded_
),
NPadded
(
NPadded_
),
KPadded
(
KPadded_
),
K0
(
K0
_
),
K0
Padded
(
K0Padded
_
),
k_batch
(
k_batch_
)
{
}
...
...
@@ -173,7 +182,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
<<
"MP:"
<<
MPadded
<<
", "
<<
"NP:"
<<
NPadded
<<
", "
<<
"KP:"
<<
KPadded
<<
", "
<<
"K0:"
<<
K0
<<
", "
<<
"K0
Padded
:"
<<
K0
Padded
<<
", "
<<
"KB:"
<<
k_batch
<<
"}"
<<
std
::
endl
;
}
};
...
...
@@ -196,7 +205,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
return
math
::
integer_least_multiple
(
N
,
NPerBlock
);
}
__host__
__device__
static
auto
CalculateK0
(
index_t
K
,
index_t
K_Batch
=
1
)
__host__
__device__
static
auto
CalculateK0
Padded
(
index_t
K
,
index_t
K_Batch
=
1
)
{
// k_batch * k0 * k0_per_block * k1
auto
K_t
=
K_Batch
*
K0PerBlock
*
K1
;
...
...
@@ -205,8 +214,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
__host__
__device__
static
auto
CalculateKPadded
(
index_t
K
,
index_t
K_Batch
=
1
)
{
auto
K0
=
CalculateK0
(
K
,
K_Batch
);
return
K_Batch
*
K0
*
K1
;
auto
K0
Padded
=
CalculateK0
Padded
(
K
,
K_Batch
);
return
K_Batch
*
K0
Padded
*
K1
;
}
__host__
__device__
static
auto
MakeAGridDescriptor_KBatch_K0_M_K1
(
index_t
M
,
...
...
@@ -214,7 +223,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
index_t
K
,
index_t
StrideA
,
index_t
KBatch
,
index_t
K0
,
index_t
K0
Padded
,
index_t
KPad
)
{
const
auto
a_grid_desc_m_k
=
[
&
]()
{
...
...
@@ -228,21 +237,33 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
}
}();
if
constexpr
(
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MKPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
)
{
const
auto
a_grid_desc_m_kpad
=
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_pass_through_transform
(
M
),
make_right_pad_transform
(
K
,
KPad
-
K
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
if
constexpr
(
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MKPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
)
{
// const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
return
transform_tensor_descriptor
(
a_grid_desc_m_kpad
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
KBatch
,
K0
,
K1
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
KBatch
,
K0Padded
,
K1
)),
make_right_pad_transform
(
M
,
MPad
-
M
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
3
>
{},
Sequence
<
2
>
{}));
}
else
if
constexpr
(
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
)
{
// const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
return
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
KBatch
,
K0Padded
,
K1
)),
make_right_pad_transform
(
M
,
MPad
-
M
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
3
>
{},
Sequence
<
2
>
{}));
...
...
@@ -250,8 +271,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
else
{
return
transform_tensor_descriptor
(
a_grid_desc_m_k
pad
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
KBatch
,
K0
,
K1
)),
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
KBatch
,
K0
Padded
,
K1
)),
make_pass_through_transform
(
M
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
3
>
{},
Sequence
<
2
>
{}));
...
...
@@ -263,7 +284,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
index_t
N
,
index_t
StrideB
,
index_t
KBatch
,
index_t
K0
,
index_t
K0
Padded
,
index_t
KPad
)
{
const
auto
b_grid_desc_k_n
=
[
&
]()
{
...
...
@@ -277,21 +298,33 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
}
}();
if
constexpr
(
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
NPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
NKPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
)
{
const
auto
b_grid_desc_kpad_n
=
transform_tensor_descriptor
(
b_grid_desc_k_n
,
make_tuple
(
make_right_pad_transform
(
K
,
KPad
-
K
),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
if
constexpr
(
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
NPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
NKPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
)
{
// const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
return
transform_tensor_descriptor
(
b_grid_desc_kpad_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
KBatch
,
K0
,
K1
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
KBatch
,
K0Padded
,
K1
)),
make_right_pad_transform
(
N
,
NPad
-
N
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
3
>
{},
Sequence
<
2
>
{}));
}
else
if
constexpr
(
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
NPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
)
{
// const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
return
transform_tensor_descriptor
(
b_grid_desc_k_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
KBatch
,
K0Padded
,
K1
)),
make_right_pad_transform
(
N
,
NPad
-
N
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
3
>
{},
Sequence
<
2
>
{}));
...
...
@@ -299,8 +332,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
else
{
return
transform_tensor_descriptor
(
b_grid_desc_k
pad
_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
KBatch
,
K0
,
K1
)),
b_grid_desc_k_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
KBatch
,
K0
Padded
,
K1
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
3
>
{},
Sequence
<
2
>
{}));
...
...
@@ -389,6 +422,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
return
false
;
}
}
if
constexpr
(
!
(
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
NPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
NKPadding
||
...
...
@@ -401,6 +435,25 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
<<
__FILE__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
<<
std
::
endl
;
#endif // DEBUG_LOG
return
false
;
}
}
if
constexpr
(
!
(
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
KPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MKPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
NKPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
))
{
auto
K_t
=
karg
.
k_batch
*
K0PerBlock
*
K1
;
if
(
!
(
karg
.
K
%
K_t
==
0
))
{
#if DEBUG_LOG
std
::
cout
<<
"Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
<<
karg
.
K
<<
" "
<<
__FILE__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
<<
std
::
endl
;
#endif // DEBUG_LOG
return
false
;
}
...
...
@@ -469,9 +522,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
if
(
karg
.
N
%
CBlockTransferScalarPerVector_NWaveNPerXDL
!=
0
)
{
#if DEBUG_LOG
std
::
cout
<<
"Arg N ("
<<
karg
.
N
<<
") value is not a multiple of
CBlockTransferScalarPerVector_NWaveNPerXDL ("
std
::
cout
<<
"Arg N ("
<<
karg
.
N
<<
") value is not a multiple of "
"
CBlockTransferScalarPerVector_NWaveNPerXDL ("
<<
CBlockTransferScalarPerVector_NWaveNPerXDL
<<
" )! "
<<
__FILE__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
<<
std
::
endl
;
...
...
@@ -484,9 +537,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
if
(
karg
.
M
%
CBlockTransferScalarPerVector_NWaveNPerXDL
!=
0
)
{
#if DEBUG_LOG
std
::
cout
<<
"Arg M ("
<<
karg
.
M
<<
") value is not a multiple of
CBlockTransferScalarPerVector_NWaveNPerXDL ("
std
::
cout
<<
"Arg M ("
<<
karg
.
M
<<
") value is not a multiple of "
"
CBlockTransferScalarPerVector_NWaveNPerXDL ("
<<
CBlockTransferScalarPerVector_NWaveNPerXDL
<<
" )! "
<<
__FILE__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
<<
std
::
endl
;
...
...
@@ -495,14 +548,14 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
}
}
const
auto
num_k_loop
=
karg
.
K0
/
K0PerBlock
;
const
auto
num_k_loop
=
karg
.
K0
Padded
/
K0PerBlock
;
if
(
!
GridwiseGemmPipe
::
IsSupported
(
num_k_loop
))
{
#if DEBUG_LOG
std
::
cout
<<
"The number of k loops ("
<<
num_k_loop
<<
") value is not supported by GridwiseGemm Pipeline."
<<
" K0: "
<<
karg
.
K0
<<
", K0PerBlock: "
<<
K0PerBlock
<<
" "
<<
__FILE__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
<<
std
::
endl
;
<<
" K0
Padded
: "
<<
karg
.
K0
Padded
<<
", K0PerBlock: "
<<
K0PerBlock
<<
" "
<<
__FILE__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
<<
std
::
endl
;
#endif // DEBUG_LOG
return
false
;
}
...
...
@@ -512,14 +565,15 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
__host__
__device__
static
auto
GetKPad
(
index_t
K
,
index_t
KBatch
)
{
const
index_t
K0
=
math
::
integer_divide_ceil
(
K
,
K1
*
K0PerBlock
*
KBatch
)
*
K0PerBlock
;
const
index_t
KPad
=
KBatch
*
K0
*
K1
;
const
index_t
K0Padded
=
math
::
integer_divide_ceil
(
K
,
K1
*
K0PerBlock
*
KBatch
)
*
K0PerBlock
;
const
index_t
KPad
=
KBatch
*
K0Padded
*
K1
;
return
KPad
;
}
__host__
__device__
static
constexpr
bool
CalculateHasMainK0BlockLoop
(
index_t
K0
)
__host__
__device__
static
constexpr
bool
CalculateHasMainK0BlockLoop
(
index_t
K0
Padded
)
{
const
index_t
num_loop
=
K0
/
K0PerBlock
;
const
index_t
num_loop
=
K0
Padded
/
K0PerBlock
;
return
GridwiseGemmPipe
::
CalculateHasMainLoop
(
num_loop
);
}
...
...
@@ -577,22 +631,22 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
typename
Block2CTileMap
>
__device__
static
void
Run
(
const
Argument
&
karg
,
void
*
__restrict__
p_shared_block
,
const
Block2CTileMap
&
block_2_ctile_map
)
const
Block2CTileMap
&
block_2_ctile_map
,
const
AElementwiseOperation
a_element_op
=
AElementwiseOperation
{},
const
BElementwiseOperation
b_element_op
=
BElementwiseOperation
{},
const
CElementwiseOperation
c_element_op
=
CElementwiseOperation
{})
{
const
FloatA
*
p_a_grid
=
karg
.
p_a_grid
;
const
FloatB
*
p_b_grid
=
karg
.
p_b_grid
;
FloatC
*
p_c_grid
=
karg
.
p_c_grid
;
const
auto
a_b_k0_m_k1_grid_desc
=
MakeAGridDescriptor_KBatch_K0_M_K1
(
karg
.
M
,
karg
.
MPadded
,
karg
.
K
,
karg
.
StrideA
,
karg
.
k_batch
,
karg
.
K0
,
karg
.
KPadded
);
karg
.
M
,
karg
.
MPadded
,
karg
.
K
,
karg
.
StrideA
,
karg
.
k_batch
,
karg
.
K0
Padded
,
karg
.
KPadded
);
const
auto
b_b_k0_n_k1_grid_desc
=
MakeBGridDescriptor_KBatch_K0_N_K1
(
karg
.
K
,
karg
.
NPadded
,
karg
.
N
,
karg
.
StrideB
,
karg
.
k_batch
,
karg
.
K0
,
karg
.
KPadded
);
karg
.
K
,
karg
.
NPadded
,
karg
.
N
,
karg
.
StrideB
,
karg
.
k_batch
,
karg
.
K0
Padded
,
karg
.
KPadded
);
const
auto
c_grid_desc_m_n
=
MakeCGridDescriptor_M_N
(
karg
.
M
,
karg
.
N
,
karg
.
StrideC
);
const
auto
c_grid_desc_mblock_mperblock_nblock_nperblock
=
MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
(
c_grid_desc_m_n
);
const
AElementwiseOperation
a_element_op
=
AElementwiseOperation
{};
const
BElementwiseOperation
b_element_op
=
BElementwiseOperation
{};
const
CElementwiseOperation
c_element_op
=
CElementwiseOperation
{};
const
auto
a_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_a_grid
,
a_b_k0_m_k1_grid_desc
.
GetElementSpaceSize
());
...
...
@@ -761,8 +815,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
auto
blockwise_gemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector
<
BlockSize
,
ComputeType
,
ComputeType
,
ComputeType
,
// ComputeType A
ComputeType
,
// ComputeType B
FloatAcc
,
decltype
(
a_k0_m_k1_block_desc
),
decltype
(
b_k0_n_k1_block_desc
),
...
...
include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp
View file @
0c823497
...
...
@@ -21,6 +21,7 @@ template <typename InputGridDesc,
typename
OutputGridDesc
,
typename
OutputDataType
,
typename
Block2ETileMap
,
typename
ComputePtrOffsetOfStridedBatch
,
typename
GridwiseTensorRearrangeKernel
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
...
...
@@ -30,13 +31,20 @@ __global__ void
const
InputDataType
*
__restrict__
p_in_global
,
const
OutputGridDesc
out_grid_desc
,
OutputDataType
*
__restrict__
p_out_global
,
const
Block2ETileMap
block_2_tile_map
)
const
index_t
batch_count
,
const
Block2ETileMap
block_2_tile_map
,
const
ComputePtrOffsetOfStridedBatch
compute_ptr_offset_of_batch
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx1030__) || defined(__gfx1100__) || \
defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx941__) || defined(__gfx942__))
GridwiseTensorRearrangeKernel
::
Run
(
in_grid_desc
,
p_in_global
,
out_grid_desc
,
p_out_global
,
block_2_tile_map
);
GridwiseTensorRearrangeKernel
::
Run
(
in_grid_desc
,
p_in_global
,
out_grid_desc
,
p_out_global
,
batch_count
,
block_2_tile_map
,
compute_ptr_offset_of_batch
);
#else
ignore
=
in_grid_desc
;
ignore
=
p_in_global
;
...
...
@@ -56,7 +64,8 @@ template <typename InputGridDesc,
typename
ThreadClusterLengths
,
index_t
ScalarPerVector
,
InMemoryDataOperationEnum
DstInMemOp
,
typename
Block2ETileMap
>
typename
Block2ETileMap
,
typename
ComputePtrOffsetOfStridedBatch
>
struct
GridwiseTensorRearrange
{
...
...
@@ -69,7 +78,9 @@ struct GridwiseTensorRearrange
const
InputDataType
*
__restrict__
p_in_global
,
const
OutputGridDesc
&
out_grid_desc
,
OutputDataType
*
__restrict__
p_out_global
,
const
Block2ETileMap
&
block_2_tile_map
)
const
index_t
batch_count
,
const
Block2ETileMap
&
block_2_tile_map
,
const
ComputePtrOffsetOfStridedBatch
&
compute_ptr_offset_of_batch
)
{
const
auto
block_work_idx
=
block_2_tile_map
.
CalculateBottomIndex
(
make_multi_index
(
get_block_1d_id
()));
...
...
@@ -80,12 +91,6 @@ struct GridwiseTensorRearrange
const
index_t
k_block_data_idx_on_grid
=
__builtin_amdgcn_readfirstlane
(
block_work_idx
[
I1
]
*
KPerBlock
);
// Global Memory
const
auto
in_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_in_global
,
in_grid_desc
.
GetElementSpaceSize
());
auto
out_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_out_global
,
out_grid_desc
.
GetElementSpaceSize
());
auto
copy_global_to_global
=
ThreadGroupTensorSliceTransfer_v7
<
ThisThreadBlock
,
Tuple
<
InputDataType
>
,
...
...
@@ -108,6 +113,22 @@ struct GridwiseTensorRearrange
make_tuple
(
make_multi_index
(
m_block_data_idx_on_grid
,
k_block_data_idx_on_grid
)),
tensor_operation
::
element_wise
::
PassThrough
{}};
const
index_t
num_blocks_per_batch
=
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
const
index_t
g_idx
=
__builtin_amdgcn_readfirstlane
(
get_block_1d_id
()
/
num_blocks_per_batch
);
// Global Memory
const
index_t
a_batch_offset
=
__builtin_amdgcn_readfirstlane
(
compute_ptr_offset_of_batch
.
GetAPtrOffset
(
g_idx
));
const
index_t
c_batch_offset
=
__builtin_amdgcn_readfirstlane
(
compute_ptr_offset_of_batch
.
GetCPtrOffset
(
g_idx
));
const
auto
in_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_in_global
+
a_batch_offset
,
in_grid_desc
.
GetElementSpaceSize
());
auto
out_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_out_global
+
c_batch_offset
,
out_grid_desc
.
GetElementSpaceSize
());
copy_global_to_global
.
Run
(
tie
(
in_grid_desc
),
tie
(
in_global_buf
),
tie
(
out_grid_desc
),
tie
(
out_global_buf
));
}
...
...
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_naive_variance.hpp
View file @
0c823497
...
...
@@ -18,9 +18,11 @@ template <typename XDataType,
typename
GammaDataType
,
typename
BetaDataType
,
typename
YDataType
,
typename
SaveMeanInvStdDataType
,
typename
ComputeDataType
,
typename
YElementwiseOperation
,
typename
GridDesc_M_K
,
typename
GridDesc_M
,
index_t
BlockSize
,
index_t
MThreadClusterSize
,
index_t
KThreadClusterSize
,
...
...
@@ -34,6 +36,7 @@ template <typename XDataType,
index_t
BetaSrcVectorSize
,
index_t
YDstVectorDim
,
index_t
YDstVectorSize
,
index_t
SaveMeanInvStdDstVectorSize
,
bool
SweepOnce
>
struct
GridwiseNormalizationNaiveVariance_mk_to_mk
{
...
...
@@ -45,6 +48,10 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
(
YDstVectorDim
==
1
&&
KThreadSliceSize
%
YDstVectorSize
==
0
),
"Invalid thread slice sizes and/or vector sizes configuration, please check!"
);
static_assert
(
MThreadSliceSize
%
SaveMeanInvStdDstVectorSize
==
0
,
"Invalid thread slice sizes and/or save mean and inverse std vector sizes "
"configuration, please check!"
);
static_assert
(
XSrcVectorSize
==
YDstVectorSize
);
static_assert
(
XSrcVectorSize
==
GammaSrcVectorSize
);
static_assert
(
XSrcVectorSize
==
BetaSrcVectorSize
);
...
...
@@ -66,6 +73,10 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
static
constexpr
auto
thread_buffer_desc_m_k
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{},
Number
<
XSrcVectorSize
>
{}));
using
ThreadBufferLengths_M
=
Sequence
<
MThreadSliceSize
>
;
static
constexpr
auto
thread_buffer_desc_m
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{}));
using
ThreadReduceSrcDesc_M_K
=
decltype
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{},
Number
<
XSrcVectorSize
>
{})));
using
ThreadReduceDstDesc_M
=
...
...
@@ -84,6 +95,8 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
reduce
::
Add
,
true
>
;
using
PassThroughOp
=
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
...
...
@@ -98,12 +111,16 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
const
GridDesc_M_K
&
gamma_grid_desc_m_k
,
const
GridDesc_M_K
&
beta_grid_desc_m_k
,
const
GridDesc_M_K
&
y_grid_desc_m_k
,
const
GridDesc_M
&
save_mean_grid_desc_m
,
const
GridDesc_M
&
save_inv_std_grid_desc_m
,
index_t
num_k_block_tile_iteration
,
ComputeDataType
epsilon
,
const
XDataType
*
const
__restrict__
p_x_global
,
const
GammaDataType
*
const
__restrict__
p_gamma_global
,
const
BetaDataType
*
const
__restrict__
p_beta_global
,
YDataType
*
const
__restrict__
p_y_global
,
SaveMeanInvStdDataType
*
const
__restrict__
p_save_mean_global
,
SaveMeanInvStdDataType
*
const
__restrict__
p_save_inv_std_global
,
const
YElementwiseOperation
y_elementwise_op
)
{
// LDS
...
...
@@ -115,6 +132,12 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
auto
y_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_y_global
,
y_grid_desc_m_k
.
GetElementSpaceSize
());
auto
save_mean_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_save_mean_global
,
save_mean_grid_desc_m
.
GetElementSpaceSize
());
auto
save_inv_std_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_save_inv_std_global
,
save_inv_std_grid_desc_m
.
GetElementSpaceSize
());
auto
x_thread_buf
=
generate_tuple
(
[
&
](
auto
)
{
return
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
...
...
@@ -152,6 +175,8 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
mean_square_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
,
MThreadSliceSize
,
true
>&
var_thread_buf
=
mean_square_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
,
MThreadSliceSize
,
true
>&
inv_std_thread_buf
=
mean_square_thread_buf
;
const
index_t
thread_local_id
=
get_thread_local_1d_id
();
const
index_t
block_global_id
=
get_block_1d_id
();
...
...
@@ -228,6 +253,42 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
thread_k_cluster_id
*
YDstVectorSize
),
y_elementwise_op
);
auto
threadwise_mean_store
=
ThreadwiseTensorSliceTransfer_v1r3
<
ComputeDataType
,
SaveMeanInvStdDataType
,
decltype
(
thread_buffer_desc_m
),
GridDesc_M
,
PassThroughOp
,
ThreadBufferLengths_M
,
Sequence
<
0
>
,
// DimAccessOrder
0
,
// SrcVectorDim
SaveMeanInvStdDstVectorSize
,
// ScalarPerVector
InMemoryDataOperationEnum
::
Set
,
1
,
true
>
(
save_mean_grid_desc_m
,
make_multi_index
(
block_global_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
),
PassThroughOp
{});
auto
threadwise_inv_std_store
=
ThreadwiseTensorSliceTransfer_v1r3
<
ComputeDataType
,
SaveMeanInvStdDataType
,
decltype
(
thread_buffer_desc_m
),
GridDesc_M
,
PassThroughOp
,
ThreadBufferLengths_M
,
Sequence
<
0
>
,
// DimAccessOrder
0
,
// SrcVectorDim
SaveMeanInvStdDstVectorSize
,
// ScalarPerVector
InMemoryDataOperationEnum
::
Set
,
1
,
true
>
(
save_inv_std_grid_desc_m
,
make_multi_index
(
block_global_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
),
PassThroughOp
{});
constexpr
auto
thread_copy_fwd_step_m_k
=
make_multi_index
(
0
,
K_BlockTileStepSize
);
constexpr
auto
thread_copy_bwd_step_m_k
=
make_multi_index
(
0
,
SweepOnce
?
0
:
-
K_BlockTileSize
);
...
...
@@ -243,7 +304,8 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
// E(x), E[x^2], var(x)
// FIXME: Should not hack the transform from deviceOP
int
reduce_length
=
x_grid_desc_m_k
.
GetTransforms
()[
I2
].
GetUpperLengths
()[
I0
];
ComputeDataType
reduce_length
=
type_convert
<
ComputeDataType
>
(
x_grid_desc_m_k
.
GetTransforms
()[
I2
].
GetUpperLengths
()[
I0
]);
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
mean_thread_buf
(
I
)
=
reduce
::
Add
::
template
GetIdentityValue
<
ComputeDataType
>();
...
...
@@ -302,10 +364,34 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
// var(x) = E[x^2] - E[x]^2
var_thread_buf
(
I
)
=
mean_square_thread_buf
(
I
)
-
(
mean_thread_buf
(
I
)
*
mean_thread_buf
(
I
));
inv_std_thread_buf
(
I
)
=
type_convert
<
ComputeDataType
>
(
1.0
f
)
/
ck
::
math
::
sqrt
(
var_thread_buf
(
I
)
+
epsilon
);
});
// save mean and inverse std for backward (optional)
if
(
thread_k_cluster_id
==
0
)
{
if
(
p_save_mean_global
!=
nullptr
)
{
threadwise_mean_store
.
Run
(
thread_buffer_desc_m
,
make_tuple
(
I0
),
mean_thread_buf
,
save_mean_grid_desc_m
,
save_mean_global_val_buf
);
}
if
(
p_save_inv_std_global
!=
nullptr
)
{
threadwise_inv_std_store
.
Run
(
thread_buffer_desc_m
,
make_tuple
(
I0
),
inv_std_thread_buf
,
save_inv_std_grid_desc_m
,
save_inv_std_global_val_buf
);
}
}
// normalization
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
iM
)
{
auto
divisor
=
1
/
ck
::
math
::
sqrt
(
var_thread_buf
(
iM
)
+
epsilon
);
static_for
<
0
,
ThreadBufferNumber
,
1
>
{}([
&
](
auto
iK0
)
{
static_for
<
0
,
XSrcVectorSize
,
1
>
{}([
&
](
auto
iK1
)
{
constexpr
auto
offset_m_k
=
...
...
@@ -314,7 +400,7 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
// normalize
y_thread_buf
(
iK0
)(
Number
<
offset_m_k
>
{})
=
(
x_thread_buf
(
iK0
)(
Number
<
offset_m_k
>
{})
-
mean_thread_buf
(
iM
))
*
divisor
;
inv_std_thread_buf
(
iM
)
;
// gamma & beta
y_thread_buf
(
iK0
)(
Number
<
offset_m_k
>
{})
=
...
...
@@ -404,8 +490,30 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
// var(x) = E[x^2] - E[x]^2
var_thread_buf
(
I
)
=
mean_square_thread_buf
(
I
)
-
(
mean_thread_buf
(
I
)
*
mean_thread_buf
(
I
));
inv_std_thread_buf
(
I
)
=
1
/
ck
::
math
::
sqrt
(
var_thread_buf
(
I
)
+
epsilon
);
});
if
(
thread_k_cluster_id
==
0
)
{
if
(
p_save_mean_global
!=
nullptr
)
{
threadwise_mean_store
.
Run
(
thread_buffer_desc_m
,
make_tuple
(
I0
),
mean_thread_buf
,
save_mean_grid_desc_m
,
save_mean_global_val_buf
);
}
if
(
p_save_inv_std_global
!=
nullptr
)
{
threadwise_inv_std_store
.
Run
(
thread_buffer_desc_m
,
make_tuple
(
I0
),
inv_std_thread_buf
,
save_inv_std_grid_desc_m
,
save_inv_std_global_val_buf
);
}
}
auto
thread_copy_tail_m_k
=
(
num_k_block_tile_iteration
-
1
)
*
ThreadBufferNumber
*
thread_copy_fwd_step_m_k
;
...
...
@@ -437,7 +545,6 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
});
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
iM
)
{
auto
divisor
=
1
/
ck
::
math
::
sqrt
(
var_thread_buf
(
iM
)
+
epsilon
);
static_for
<
0
,
ThreadBufferNumber
,
1
>
{}([
&
](
auto
iK0
)
{
static_for
<
0
,
XSrcVectorSize
,
1
>
{}([
&
](
auto
iK1
)
{
constexpr
auto
offset_m_k
=
...
...
@@ -446,7 +553,7 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
// normalize
y_thread_buf
(
iK0
)(
Number
<
offset_m_k
>
{})
=
(
x_thread_buf
(
iK0
)(
Number
<
offset_m_k
>
{})
-
mean_thread_buf
(
iM
))
*
divisor
;
inv_std_thread_buf
(
iM
)
;
// gamma
y_thread_buf
(
iK0
)(
Number
<
offset_m_k
>
{})
=
...
...
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_selector.hpp
View file @
0c823497
...
...
@@ -12,31 +12,42 @@ template <typename GridwiseReduction,
typename
GammaDataType
,
typename
BetaDataType
,
typename
YDataType
,
typename
SaveMeanInvStdDataType
,
typename
ComputeDataType
,
typename
YElementwiseOperation
,
typename
GridDesc_M_K
>
__global__
void
kernel_normalization
(
const
GridDesc_M_K
x_grid_desc_m_k
,
typename
GridDesc_M_K
,
typename
GridDesc_M
>
__global__
void
kernel_normalization
(
const
GridDesc_M_K
x_grid_desc_m_k
,
const
GridDesc_M_K
gamma_grid_desc_m_k
,
const
GridDesc_M_K
beta_grid_desc_m_k
,
const
GridDesc_M_K
y_grid_desc_m_k
,
const
GridDesc_M
save_mean_grid_desc_m
,
const
GridDesc_M
save_inv_std_grid_desc_m
,
index_t
num_k_block_tile_iteration
,
ComputeDataType
epsilon
,
const
XDataType
*
const
__restrict__
p_x_global
,
const
GammaDataType
*
const
__restrict__
p_gamma_global
,
const
BetaDataType
*
const
__restrict__
p_beta_global
,
YDataType
*
const
__restrict__
p_y_global
,
SaveMeanInvStdDataType
*
const
__restrict__
p_save_mean_global
,
SaveMeanInvStdDataType
*
const
__restrict__
p_save_inv_std_global
,
const
YElementwiseOperation
y_elementwise_op
)
{
GridwiseReduction
::
Run
(
x_grid_desc_m_k
,
gamma_grid_desc_m_k
,
beta_grid_desc_m_k
,
y_grid_desc_m_k
,
save_mean_grid_desc_m
,
save_inv_std_grid_desc_m
,
num_k_block_tile_iteration
,
epsilon
,
p_x_global
,
p_gamma_global
,
p_beta_global
,
p_y_global
,
p_save_mean_global
,
p_save_inv_std_global
,
y_elementwise_op
);
};
...
...
@@ -44,9 +55,11 @@ template <typename XDataType,
typename
GammaDataType
,
typename
BetaDataType
,
typename
YDataType
,
typename
SaveMeanInvStdDataType
,
typename
ComputeDataType
,
typename
YElementwiseOperation
,
typename
GridDesc_M_K
,
typename
GridDesc_M
,
index_t
BlockSize
,
index_t
MThreadClusterSize
,
index_t
KThreadClusterSize
,
...
...
@@ -60,6 +73,7 @@ template <typename XDataType,
index_t
BetaSrcVectorSize
,
index_t
YDstVectorDim
,
index_t
YDstVectorSize
,
index_t
SaveMeanInvStdDstVectorSize
,
bool
UseWelford
>
auto
NormalizationKernelSelector
(
bool
isSweepOnce
)
{
...
...
@@ -68,9 +82,11 @@ auto NormalizationKernelSelector(bool isSweepOnce)
GammaDataType
,
BetaDataType
,
YDataType
,
SaveMeanInvStdDataType
,
ComputeDataType
,
YElementwiseOperation
,
GridDesc_M_K
,
GridDesc_M
,
BlockSize
,
MThreadClusterSize
,
KThreadClusterSize
,
...
...
@@ -84,15 +100,18 @@ auto NormalizationKernelSelector(bool isSweepOnce)
BetaSrcVectorSize
,
YDstVectorDim
,
YDstVectorSize
,
SaveMeanInvStdDstVectorSize
,
false
>
;
using
GridwiseNormalizationSweepOnceNaive
=
GridwiseNormalizationNaiveVariance_mk_to_mk
<
XDataType
,
GammaDataType
,
BetaDataType
,
YDataType
,
SaveMeanInvStdDataType
,
ComputeDataType
,
YElementwiseOperation
,
GridDesc_M_K
,
GridDesc_M
,
BlockSize
,
MThreadClusterSize
,
KThreadClusterSize
,
...
...
@@ -106,15 +125,18 @@ auto NormalizationKernelSelector(bool isSweepOnce)
BetaSrcVectorSize
,
YDstVectorDim
,
YDstVectorSize
,
SaveMeanInvStdDstVectorSize
,
true
>
;
using
GridwiseNormalizationGenericWelford
=
GridwiseNormalizationWelfordVariance_mk_to_mk
<
XDataType
,
GammaDataType
,
BetaDataType
,
YDataType
,
SaveMeanInvStdDataType
,
ComputeDataType
,
YElementwiseOperation
,
GridDesc_M_K
,
GridDesc_M
,
BlockSize
,
MThreadClusterSize
,
KThreadClusterSize
,
...
...
@@ -128,15 +150,18 @@ auto NormalizationKernelSelector(bool isSweepOnce)
BetaSrcVectorSize
,
YDstVectorDim
,
YDstVectorSize
,
SaveMeanInvStdDstVectorSize
,
false
>
;
using
GridwiseNormalizationSweepOnceWelford
=
GridwiseNormalizationWelfordVariance_mk_to_mk
<
XDataType
,
GammaDataType
,
BetaDataType
,
YDataType
,
SaveMeanInvStdDataType
,
ComputeDataType
,
YElementwiseOperation
,
GridDesc_M_K
,
GridDesc_M
,
BlockSize
,
MThreadClusterSize
,
KThreadClusterSize
,
...
...
@@ -150,6 +175,7 @@ auto NormalizationKernelSelector(bool isSweepOnce)
BetaSrcVectorSize
,
YDstVectorDim
,
YDstVectorSize
,
SaveMeanInvStdDstVectorSize
,
true
>
;
if
constexpr
(
UseWelford
)
...
...
@@ -159,17 +185,21 @@ auto NormalizationKernelSelector(bool isSweepOnce)
GammaDataType
,
BetaDataType
,
YDataType
,
SaveMeanInvStdDataType
,
ComputeDataType
,
YElementwiseOperation
,
GridDesc_M_K
>
GridDesc_M_K
,
GridDesc_M
>
:
kernel_normalization
<
GridwiseNormalizationGenericWelford
,
XDataType
,
GammaDataType
,
BetaDataType
,
YDataType
,
SaveMeanInvStdDataType
,
ComputeDataType
,
YElementwiseOperation
,
GridDesc_M_K
>
;
GridDesc_M_K
,
GridDesc_M
>
;
}
else
{
...
...
@@ -178,17 +208,21 @@ auto NormalizationKernelSelector(bool isSweepOnce)
GammaDataType
,
BetaDataType
,
YDataType
,
SaveMeanInvStdDataType
,
ComputeDataType
,
YElementwiseOperation
,
GridDesc_M_K
>
GridDesc_M_K
,
GridDesc_M
>
:
kernel_normalization
<
GridwiseNormalizationGenericNaive
,
XDataType
,
GammaDataType
,
BetaDataType
,
YDataType
,
SaveMeanInvStdDataType
,
ComputeDataType
,
YElementwiseOperation
,
GridDesc_M_K
>
;
GridDesc_M_K
,
GridDesc_M
>
;
}
}
...
...
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_2nd.hpp
View file @
0c823497
...
...
@@ -17,11 +17,13 @@ template <typename MeanVarDataType,
typename
GammaDataType
,
typename
BetaDataType
,
typename
YDataType
,
typename
SaveMeanInvStdDataType
,
typename
ComputeDataType
,
typename
YElementwiseOperation
,
typename
MeanVarGridDesc_M_KBlock
,
typename
CountGridDesc_M_KBlock
,
typename
XYGammaBetaGridDesc_M_K
,
typename
SaveMeanInvStdGridDesc_M
,
index_t
BlockSize
,
index_t
MThreadClusterSize
,
index_t
KThreadClusterSize
,
...
...
@@ -34,7 +36,8 @@ template <typename MeanVarDataType,
index_t
BetaSrcVectorDim
,
index_t
BetaSrcVectorSize
,
index_t
YDstVectorDim
,
index_t
YDstVectorSize
>
index_t
YDstVectorSize
,
index_t
SaveMeanInvStdDstVectorSize
>
struct
GridwiseNormalizationSplitK2nd
{
static_assert
((
XSrcVectorDim
==
0
&&
MThreadSliceSize
%
XSrcVectorSize
==
0
)
||
...
...
@@ -45,6 +48,10 @@ struct GridwiseNormalizationSplitK2nd
(
YDstVectorDim
==
1
&&
KThreadSliceSize
%
YDstVectorSize
==
0
),
"Invalid thread slice sizes and/or vector sizes configuration, please check!"
);
static_assert
(
MThreadSliceSize
%
SaveMeanInvStdDstVectorSize
==
0
,
"Invalid thread slice sizes and/or save mean and inverse std vector sizes "
"configuration, please check!"
);
static_assert
(
XSrcVectorSize
==
YDstVectorSize
);
static_assert
(
XSrcVectorSize
==
GammaSrcVectorSize
);
static_assert
(
XSrcVectorSize
==
BetaSrcVectorSize
);
...
...
@@ -69,6 +76,10 @@ struct GridwiseNormalizationSplitK2nd
static
constexpr
auto
thread_buffer_desc_m_k
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{},
Number
<
XSrcVectorSize
>
{}));
using
ThreadBufferLengths_M
=
Sequence
<
MThreadSliceSize
>
;
static
constexpr
auto
thread_buffer_desc_m
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{}));
using
ThreadBufferLengths_M_1
=
Sequence
<
MThreadSliceSize
,
1
>
;
static
constexpr
auto
thread_buffer_desc_m_1
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{},
I1
));
...
...
@@ -99,6 +110,8 @@ struct GridwiseNormalizationSplitK2nd
const
XYGammaBetaGridDesc_M_K
&
gamma_grid_desc_m_k
,
const
XYGammaBetaGridDesc_M_K
&
beta_grid_desc_m_k
,
const
XYGammaBetaGridDesc_M_K
&
y_grid_desc_m_k
,
const
SaveMeanInvStdGridDesc_M
&
save_mean_grid_desc_m
,
const
SaveMeanInvStdGridDesc_M
&
save_inv_std_grid_desc_m
,
index_t
num_k_mean_var_count_iteration
,
index_t
num_k_block_tile_iteration
,
index_t
k_grid_size
,
...
...
@@ -110,6 +123,8 @@ struct GridwiseNormalizationSplitK2nd
const
GammaDataType
*
const
__restrict__
p_gamma_global
,
const
BetaDataType
*
const
__restrict__
p_beta_global
,
YDataType
*
const
__restrict__
p_y_global
,
SaveMeanInvStdDataType
*
const
__restrict__
p_save_mean_global
,
SaveMeanInvStdDataType
*
const
__restrict__
p_save_inv_std_global
,
const
YElementwiseOperation
y_elementwise_op
)
{
// Thread/Block id
...
...
@@ -145,6 +160,12 @@ struct GridwiseNormalizationSplitK2nd
auto
y_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_y_global
,
y_grid_desc_m_k
.
GetElementSpaceSize
());
auto
save_mean_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_save_mean_global
,
save_mean_grid_desc_m
.
GetElementSpaceSize
());
auto
save_inv_std_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_save_inv_std_global
,
save_inv_std_grid_desc_m
.
GetElementSpaceSize
());
// VGPR
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
,
MThreadSliceSize
,
true
>
in_mean_thread_buf
;
...
...
@@ -158,6 +179,7 @@ struct GridwiseNormalizationSplitK2nd
var_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
int32_t
,
MThreadSliceSize
,
true
>
welford_count_thread_buf
;
auto
&
inv_std_thread_buf
=
var_thread_buf
;
auto
x_thread_buf
=
generate_tuple
(
[
&
](
auto
)
{
...
...
@@ -283,6 +305,42 @@ struct GridwiseNormalizationSplitK2nd
thread_k_cluster_id
*
YDstVectorSize
),
y_elementwise_op
);
auto
threadwise_mean_store
=
ThreadwiseTensorSliceTransfer_v1r3
<
ComputeDataType
,
SaveMeanInvStdDataType
,
decltype
(
thread_buffer_desc_m
),
SaveMeanInvStdGridDesc_M
,
PassThroughOp
,
ThreadBufferLengths_M
,
Sequence
<
0
>
,
// DimAccessOrder
0
,
// SrcVectorDim
SaveMeanInvStdDstVectorSize
,
// ScalarPerVector
InMemoryDataOperationEnum
::
Set
,
1
,
true
>
(
save_mean_grid_desc_m
,
make_multi_index
(
block_m_cluster_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
),
PassThroughOp
{});
auto
threadwise_inv_std_store
=
ThreadwiseTensorSliceTransfer_v1r3
<
ComputeDataType
,
SaveMeanInvStdDataType
,
decltype
(
thread_buffer_desc_m
),
SaveMeanInvStdGridDesc_M
,
PassThroughOp
,
ThreadBufferLengths_M
,
Sequence
<
0
>
,
// DimAccessOrder
0
,
// SrcVectorDim
SaveMeanInvStdDstVectorSize
,
// ScalarPerVector
InMemoryDataOperationEnum
::
Set
,
1
,
true
>
(
save_inv_std_grid_desc_m
,
make_multi_index
(
block_m_cluster_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
),
PassThroughOp
{});
// step1: Merge mean and variance
constexpr
auto
mean_var_count_thread_copy_step_I0_k
=
make_multi_index
(
I0
,
KThreadClusterSize
);
...
...
@@ -332,9 +390,33 @@ struct GridwiseNormalizationSplitK2nd
BlockwiseWelford
::
Run
(
mean_thread_buf
(
I
),
var_thread_buf
(
I
),
welford_count_thread_buf
(
I
));
inv_std_thread_buf
(
I
)
=
type_convert
<
ComputeDataType
>
(
1.0
f
)
/
ck
::
math
::
sqrt
(
var_thread_buf
(
I
)
+
epsilon
);
});
// step2: normalization
// step2: save mean and inverse std for backward (optional)
if
(
block_k_cluster_id
==
0
&&
thread_k_cluster_id
==
0
)
{
if
(
p_save_mean_global
!=
nullptr
)
{
threadwise_mean_store
.
Run
(
thread_buffer_desc_m
,
make_tuple
(
I0
),
mean_thread_buf
,
save_mean_grid_desc_m
,
save_mean_global_val_buf
);
}
if
(
p_save_inv_std_global
!=
nullptr
)
{
threadwise_inv_std_store
.
Run
(
thread_buffer_desc_m
,
make_tuple
(
I0
),
inv_std_thread_buf
,
save_inv_std_grid_desc_m
,
save_inv_std_global_val_buf
);
}
}
// step3: normalization
constexpr
auto
thread_copy_fwd_step_m_k
=
make_multi_index
(
0
,
K_BlockTileStepSize
);
for
(
index_t
k
=
0
;
k
<
num_k_block_tile_iteration
;
++
k
)
...
...
@@ -360,7 +442,6 @@ struct GridwiseNormalizationSplitK2nd
});
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
iM
)
{
auto
divisor
=
1
/
ck
::
math
::
sqrt
(
var_thread_buf
(
iM
)
+
epsilon
);
static_for
<
0
,
ThreadBufferNumber
,
1
>
{}([
&
](
auto
iK0
)
{
static_for
<
0
,
XSrcVectorSize
,
1
>
{}([
&
](
auto
iK1
)
{
constexpr
auto
offset_m_k
=
...
...
@@ -369,7 +450,7 @@ struct GridwiseNormalizationSplitK2nd
// normalize
y_thread_buf
(
iK0
)(
Number
<
offset_m_k
>
{})
=
(
x_thread_buf
(
iK0
)(
Number
<
offset_m_k
>
{})
-
mean_thread_buf
(
iM
))
*
divisor
;
inv_std_thread_buf
(
iM
)
;
// gamma
y_thread_buf
(
iK0
)(
Number
<
offset_m_k
>
{})
=
...
...
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_welford_variance.hpp
View file @
0c823497
...
...
@@ -16,9 +16,11 @@ template <typename XDataType,
typename
GammaDataType
,
typename
BetaDataType
,
typename
YDataType
,
typename
SaveMeanInvStdDataType
,
typename
ComputeDataType
,
typename
YElementwiseOperation
,
typename
GridDesc_M_K
,
typename
GridDesc_M
,
index_t
BlockSize
,
index_t
MThreadClusterSize
,
index_t
KThreadClusterSize
,
...
...
@@ -32,6 +34,7 @@ template <typename XDataType,
index_t
BetaSrcVectorSize
,
index_t
YDstVectorDim
,
index_t
YDstVectorSize
,
index_t
SaveMeanInvStdDstVectorSize
,
bool
SweepOnce
>
struct
GridwiseNormalizationWelfordVariance_mk_to_mk
{
...
...
@@ -43,6 +46,10 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
(
YDstVectorDim
==
1
&&
KThreadSliceSize
%
YDstVectorSize
==
0
),
"Invalid thread slice sizes and/or vector sizes configuration, please check!"
);
static_assert
(
MThreadSliceSize
%
SaveMeanInvStdDstVectorSize
==
0
,
"Invalid thread slice sizes and/or save mean and inverse std vector sizes "
"configuration, please check!"
);
static_assert
(
XSrcVectorSize
==
YDstVectorSize
);
static_assert
(
XSrcVectorSize
==
GammaSrcVectorSize
);
static_assert
(
XSrcVectorSize
==
BetaSrcVectorSize
);
...
...
@@ -64,6 +71,10 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
static
constexpr
auto
thread_buffer_desc_m_k
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{},
Number
<
XSrcVectorSize
>
{}));
using
ThreadBufferLengths_M
=
Sequence
<
MThreadSliceSize
>
;
static
constexpr
auto
thread_buffer_desc_m
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{}));
using
ThreadReduceSrcDesc_M_K
=
decltype
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{},
Number
<
XSrcVectorSize
>
{})));
using
ThreadReduceDstDesc_M
=
...
...
@@ -77,6 +88,8 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
ThreadClusterLengths_M_K
,
ThreadClusterArrangeOrder
>
;
using
PassThroughOp
=
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
...
...
@@ -114,17 +127,18 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
const
GridDesc_M_K
&
gamma_grid_desc_m_k
,
const
GridDesc_M_K
&
beta_grid_desc_m_k
,
const
GridDesc_M_K
&
y_grid_desc_m_k
,
const
GridDesc_M
&
save_mean_grid_desc_m
,
const
GridDesc_M
&
save_inv_std_grid_desc_m
,
index_t
num_k_block_tile_iteration
,
ComputeDataType
epsilon
,
const
XDataType
*
const
__restrict__
p_x_global
,
const
GammaDataType
*
const
__restrict__
p_gamma_global
,
const
BetaDataType
*
const
__restrict__
p_beta_global
,
YDataType
*
const
__restrict__
p_y_global
,
SaveMeanInvStdDataType
*
const
__restrict__
p_save_mean_global
,
SaveMeanInvStdDataType
*
const
__restrict__
p_save_inv_std_global
,
const
YElementwiseOperation
y_elementwise_op
)
{
auto
y_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_y_global
,
y_grid_desc_m_k
.
GetElementSpaceSize
());
auto
x_thread_buf
=
generate_tuple
(
[
&
](
auto
)
{
return
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
...
...
@@ -150,6 +164,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
mean_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
,
MThreadSliceSize
,
true
>
var_thread_buf
;
auto
&
inv_std_thread_buf
=
var_thread_buf
;
const
index_t
thread_local_id
=
get_thread_local_1d_id
();
const
index_t
block_global_id
=
get_block_1d_id
();
...
...
@@ -226,6 +241,42 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
thread_k_cluster_id
*
YDstVectorSize
),
y_elementwise_op
);
auto
threadwise_mean_store
=
ThreadwiseTensorSliceTransfer_v1r3
<
ComputeDataType
,
SaveMeanInvStdDataType
,
decltype
(
thread_buffer_desc_m
),
GridDesc_M
,
PassThroughOp
,
ThreadBufferLengths_M
,
Sequence
<
0
>
,
// DimAccessOrder
0
,
// SrcVectorDim
SaveMeanInvStdDstVectorSize
,
// ScalarPerVector
InMemoryDataOperationEnum
::
Set
,
1
,
true
>
(
save_mean_grid_desc_m
,
make_multi_index
(
block_global_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
),
PassThroughOp
{});
auto
threadwise_inv_std_store
=
ThreadwiseTensorSliceTransfer_v1r3
<
ComputeDataType
,
SaveMeanInvStdDataType
,
decltype
(
thread_buffer_desc_m
),
GridDesc_M
,
PassThroughOp
,
ThreadBufferLengths_M
,
Sequence
<
0
>
,
// DimAccessOrder
0
,
// SrcVectorDim
SaveMeanInvStdDstVectorSize
,
// ScalarPerVector
InMemoryDataOperationEnum
::
Set
,
1
,
true
>
(
save_inv_std_grid_desc_m
,
make_multi_index
(
block_global_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
),
PassThroughOp
{});
constexpr
auto
thread_copy_fwd_step_m_k
=
make_multi_index
(
0
,
K_BlockTileStepSize
);
constexpr
auto
thread_copy_bwd_step_m_k
=
make_multi_index
(
0
,
SweepOnce
?
0
:
-
K_BlockTileSize
);
...
...
@@ -239,6 +290,15 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
const
auto
beta_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_beta_global
,
beta_grid_desc_m_k
.
GetElementSpaceSize
());
auto
y_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_y_global
,
y_grid_desc_m_k
.
GetElementSpaceSize
());
auto
save_mean_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_save_mean_global
,
save_mean_grid_desc_m
.
GetElementSpaceSize
());
auto
save_inv_std_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_save_inv_std_global
,
save_inv_std_grid_desc_m
.
GetElementSpaceSize
());
auto
threadwise_welford
=
ThreadwiseWelford
();
threadwise_welford
.
max_count_
=
GetKPerThread
(
x_grid_desc_m_k
,
thread_k_cluster_id
);
...
...
@@ -279,10 +339,33 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
int
count
=
threadwise_welford
.
cur_count_
;
BlockwiseWelford
::
Run
(
mean_thread_buf
(
I
),
var_thread_buf
(
I
),
count
);
inv_std_thread_buf
(
I
)
=
type_convert
<
ComputeDataType
>
(
1.0
f
)
/
ck
::
math
::
sqrt
(
var_thread_buf
(
I
)
+
epsilon
);
});
// save mean and inverse std for backward (optional)
if
(
thread_k_cluster_id
==
0
)
{
if
(
p_save_mean_global
!=
nullptr
)
{
threadwise_mean_store
.
Run
(
thread_buffer_desc_m
,
make_tuple
(
I0
),
mean_thread_buf
,
save_mean_grid_desc_m
,
save_mean_global_val_buf
);
}
if
(
p_save_inv_std_global
!=
nullptr
)
{
threadwise_inv_std_store
.
Run
(
thread_buffer_desc_m
,
make_tuple
(
I0
),
inv_std_thread_buf
,
save_inv_std_grid_desc_m
,
save_inv_std_global_val_buf
);
}
}
// normalization
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
iM
)
{
auto
divisor
=
1
/
ck
::
math
::
sqrt
(
var_thread_buf
(
iM
)
+
epsilon
);
static_for
<
0
,
ThreadBufferNumber
,
1
>
{}([
&
](
auto
iK0
)
{
static_for
<
0
,
XSrcVectorSize
,
1
>
{}([
&
](
auto
iK1
)
{
constexpr
auto
offset_m_k
=
...
...
@@ -291,7 +374,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
// normalize
y_thread_buf
(
iK0
)(
Number
<
offset_m_k
>
{})
=
(
x_thread_buf
(
iK0
)(
Number
<
offset_m_k
>
{})
-
mean_thread_buf
(
iM
))
*
divisor
;
inv_std_thread_buf
(
iM
)
;
// gamma & beta
y_thread_buf
(
iK0
)(
Number
<
offset_m_k
>
{})
=
...
...
@@ -360,8 +443,29 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
int
count
=
threadwise_welford
.
cur_count_
;
BlockwiseWelford
::
Run
(
mean_thread_buf
(
I
),
var_thread_buf
(
I
),
count
);
inv_std_thread_buf
(
I
)
=
1
/
ck
::
math
::
sqrt
(
var_thread_buf
(
I
)
+
epsilon
);
});
if
(
thread_k_cluster_id
==
0
)
{
if
(
p_save_mean_global
!=
nullptr
)
{
threadwise_mean_store
.
Run
(
thread_buffer_desc_m
,
make_tuple
(
I0
),
mean_thread_buf
,
save_mean_grid_desc_m
,
save_mean_global_val_buf
);
}
if
(
p_save_inv_std_global
!=
nullptr
)
{
threadwise_inv_std_store
.
Run
(
thread_buffer_desc_m
,
make_tuple
(
I0
),
inv_std_thread_buf
,
save_inv_std_grid_desc_m
,
save_inv_std_global_val_buf
);
}
}
auto
thread_copy_tail_m_k
=
(
num_k_block_tile_iteration
-
1
)
*
ThreadBufferNumber
*
thread_copy_fwd_step_m_k
;
...
...
@@ -393,7 +497,6 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
});
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
iM
)
{
auto
divisor
=
1
/
ck
::
math
::
sqrt
(
var_thread_buf
(
iM
)
+
epsilon
);
static_for
<
0
,
ThreadBufferNumber
,
1
>
{}([
&
](
auto
iK0
)
{
static_for
<
0
,
XSrcVectorSize
,
1
>
{}([
&
](
auto
iK1
)
{
constexpr
auto
offset_m_k
=
...
...
@@ -402,7 +505,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
// normalize
y_thread_buf
(
iK0
)(
Number
<
offset_m_k
>
{})
=
(
x_thread_buf
(
iK0
)(
Number
<
offset_m_k
>
{})
-
mean_thread_buf
(
iM
))
*
divisor
;
inv_std_thread_buf
(
iM
)
;
// gamma
y_thread_buf
(
iK0
)(
Number
<
offset_m_k
>
{})
=
...
...
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp
View file @
0c823497
...
...
@@ -9,6 +9,7 @@
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor/static_tensor.hpp"
#include "ck/utility/is_detected.hpp"
namespace
ck
{
...
...
@@ -211,10 +212,44 @@ struct ThreadwiseTensorSliceTransfer_v3r1
auto
src_vector_container
=
src_vector_type
{
src_buf
.
template
Get
<
src_vector_t
>(
src_coord_
.
GetOffset
(),
is_src_valid
)};
using
dst_vector_type
=
vector_type_maker_t
<
DstData
,
SrcScalarPerVector
>
;
using
dst_vector_t
=
typename
dst_vector_type
::
type
;
dst_vector_type
op_r_v
;
constexpr
auto
get_elem_op_vec_len
=
[]()
{
if
constexpr
(
is_detected
<
is_pack8_invocable_t
,
decltype
(
src_element_op_
)
>::
value
)
{
if
constexpr
(
decltype
(
src_element_op_
)
::
is_pack8_invocable
)
return
math
::
min
(
8
,
SrcScalarPerVector
);
}
if
constexpr
(
is_detected
<
is_pack4_invocable_t
,
decltype
(
src_element_op_
)
>::
value
)
{
if
constexpr
(
decltype
(
src_element_op_
)
::
is_pack4_invocable
)
return
math
::
min
(
4
,
SrcScalarPerVector
);
}
if
constexpr
(
is_detected
<
is_pack2_invocable_t
,
decltype
(
src_element_op_
)
>::
value
)
{
if
constexpr
(
decltype
(
src_element_op_
)
::
is_pack2_invocable
)
return
math
::
min
(
2
,
SrcScalarPerVector
);
}
return
1
;
};
constexpr
index_t
elem_op_vec_len
=
get_elem_op_vec_len
();
using
src_elem_op_vec_t
=
typename
vector_type
<
SrcData
,
elem_op_vec_len
>::
type
;
using
dst_elem_op_vec_t
=
typename
vector_type
<
DstData
,
elem_op_vec_len
>::
type
;
static_for
<
0
,
SrcScalarPerVector
/
elem_op_vec_len
,
1
>
{}([
&
](
auto
idx
)
{
// apply the src elementwise op and convert to DstData under the hood if needed
src_element_op_
(
op_r_v
.
template
AsType
<
dst_elem_op_vec_t
>()(
idx
),
src_vector_container
.
template
AsType
<
src_elem_op_vec_t
>()[
idx
]);
});
// copy data from src_vector_container into src_thread_scratch_
src_thread_scratch_tuple_
(
thread_scratch_id
)
.
template
SetAsType
<
src
_vector_t
>(
src_data_idx_seq
,
src_vector_container
.
template
AsType
<
src
_vector_t
>()[
I0
]);
.
template
SetAsType
<
dst
_vector_t
>(
src_data_idx_seq
,
op_r_v
.
template
AsType
<
dst
_vector_t
>()[
I0
]);
constexpr
auto
move_on_dim
=
[
&
]()
constexpr
{
...
...
@@ -267,19 +302,15 @@ struct ThreadwiseTensorSliceTransfer_v3r1
{
#if !CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE
static_ford
<
SliceLengths
>
{}([
&
](
auto
idx
)
{
// convert from SrcData to DstData here
dst_thread_scratch_
(
idx
)
=
type_convert
<
DstData
>
(
src_thread_scratch_tuple_
[
thread_scratch_id
][
idx
]);
dst_thread_scratch_
(
idx
)
=
src_thread_scratch_tuple_
[
thread_scratch_id
][
idx
];
});
#else
// sub-dword transpose between src_thread_scratch_ and dst_thread_scratch_
// TODO make this logic more generic for more sub-dword datatype
if
constexpr
(
SrcVectorDim
!=
DstVectorDim
&&
((
is_same
<
half_t
,
remove_cvref_t
<
SrcData
>>::
value
&&
is_same
<
half_t
,
remove_cvref_t
<
DstData
>>::
value
&&
((
is_same
<
half_t
,
remove_cvref_t
<
DstData
>>::
value
&&
SrcScalarPerVector
%
2
==
0
&&
DstScalarPerVector
%
2
==
0
)
||
(
is_same
<
int8_t
,
remove_cvref_t
<
SrcData
>>::
value
&&
is_same
<
int8_t
,
remove_cvref_t
<
DstData
>>::
value
&&
(
is_same
<
int8_t
,
remove_cvref_t
<
DstData
>>::
value
&&
SrcScalarPerVector
%
4
==
0
&&
DstScalarPerVector
%
4
==
0
)))
{
// each transpose does
...
...
@@ -313,7 +344,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
constexpr
auto
data_idx_seq
=
generate_sequence_v2
(
[
&
](
auto
i
)
{
return
Number
<
data_idx
[
i
]
>
{};
},
Number
<
nDim
>
{});
using
src_vector_t
=
vector_type_maker_t
<
Src
Data
,
SrcScalarPerVector
>
;
using
src_vector_t
=
vector_type_maker_t
<
Dst
Data
,
SrcScalarPerVector
>
;
using
dst_vector_t
=
vector_type_maker_t
<
DstData
,
DstScalarPerVector
>
;
// get DstScalarPerVector # of read-only references to src vectors from
...
...
@@ -336,17 +367,16 @@ struct ThreadwiseTensorSliceTransfer_v3r1
Number
<
num_dst_vector
>
{});
// do data transpose
transpose_vectors
<
Src
Data
,
DstScalarPerVector
,
SrcScalarPerVector
>
{}(
transpose_vectors
<
Dst
Data
,
DstScalarPerVector
,
SrcScalarPerVector
>
{}(
src_vector_refs
,
dst_vector_refs
);
});
}
else
{
static_ford
<
SliceLengths
>
{}([
&
](
auto
idx
)
{
// apply the src elementwise op and convert to DstData under the hood if needed
DstData
dst_v
;
src_element_op_
(
dst_v
,
src_thread_scratch_tuple_
[
thread_scratch_id
][
idx
]);
dst_thread_scratch_
(
idx
)
=
dst_v
;
dst_thread_scratch_
(
idx
)
=
src_thread_scratch_tuple_
[
thread_scratch_id
][
idx
];
});
}
#endif
}
...
...
@@ -761,8 +791,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1
static
constexpr
auto
src_thread_scratch_desc_
=
decltype
(
GetSrcThreadScratchDescriptor
()){};
static
constexpr
auto
dst_thread_scratch_desc_
=
decltype
(
GetDstThreadScratchDescriptor
()){};
using
SrcThreadScratch
=
StaticTensorTupleOfVectorBuffer
<
AddressSpaceEnum
::
Vgpr
,
SrcData
,
using
SrcThreadScratch
=
StaticTensorTupleOfVectorBuffer
<
AddressSpaceEnum
::
Vgpr
,
DstData
,
// apply data_convert with SrcThreadScratch
SrcScalarPerVector
,
decltype
(
src_thread_scratch_desc_
),
true
>
;
...
...
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp
View file @
0c823497
...
...
@@ -132,9 +132,6 @@ struct ThreadwiseTensorSliceTransfer_v7r2
Number
<
num
>
{});
}
template
<
typename
T
>
using
has_vec_len
=
decltype
(
std
::
declval
<
T
&>
().
vec_len
);
// SrcDescs: Tuple<const SrcDesc0&, const SrcDesc1&, ...>
// SrcBuffers: Tuple<const SrcBuffer0&, const SrcBuffer1&, ...>
template
<
typename
SrcBuffers
,
...
...
@@ -159,19 +156,26 @@ struct ThreadwiseTensorSliceTransfer_v7r2
is_src_valid
);
});
if
constexpr
(
is_detected
<
has_vec_len
,
decltype
(
element_op_
)
>::
value
)
constexpr
auto
get_elem_op_vec_len
=
[]()
{
if
constexpr
(
is_detected
<
is_pack8_invocable_t
,
decltype
(
element_op_
)
>::
value
)
{
constexpr
auto
elem_op_vec_len
=
decltype
(
element_op_
)
::
vec_len
;
static_assert
(
is_same
<
remove_cvref_t
<
decltype
(
elem_op_vec_len
)
>
,
index_t
>::
value
,
"vec_len in element_op_ type is not index_t"
);
static_assert
(
elem_op_vec_len
==
1
||
elem_op_vec_len
==
2
||
elem_op_vec_len
==
4
||
elem_op_vec_len
==
8
,
"vec_len in element_op_ must be 1, 2, 4, 8"
);
if
constexpr
(
decltype
(
element_op_
)
::
is_pack8_invocable
)
return
math
::
min
(
8
,
SrcScalarPerVector
);
}
if
constexpr
(
is_detected
<
is_pack4_invocable_t
,
decltype
(
element_op_
)
>::
value
)
{
if
constexpr
(
decltype
(
element_op_
)
::
is_pack4_invocable
)
return
math
::
min
(
4
,
SrcScalarPerVector
);
}
if
constexpr
(
is_detected
<
is_pack2_invocable_t
,
decltype
(
element_op_
)
>::
value
)
{
if
constexpr
(
decltype
(
element_op_
)
::
is_pack2_invocable
)
return
math
::
min
(
2
,
SrcScalarPerVector
);
}
return
1
;
};
static_assert
(
SrcScalarPerVector
%
elem_op_vec_len
==
0
,
"vec_len in element_op_ cannot be divided by SrcScalarPerVector!"
);
constexpr
index_t
elem_op_vec_len
=
get_elem_op_vec_len
();
// apply pointwise function
static_for
<
0
,
SrcScalarPerVector
/
elem_op_vec_len
,
1
>
{}([
&
](
auto
i
)
{
...
...
@@ -181,8 +185,7 @@ struct ThreadwiseTensorSliceTransfer_v7r2
[
&
](
auto
iSrc
)
->
const
auto
&
{
using
SrcData
=
remove_cvref_t
<
tuple_element_t
<
iSrc
.
value
,
SrcDatas
>>
;
using
elem_op_vec_t
=
typename
vector_type
<
SrcData
,
elem_op_vec_len
>::
type
;
using
elem_op_vec_t
=
typename
vector_type
<
SrcData
,
elem_op_vec_len
>::
type
;
return
src_vectors
[
iSrc
].
template
AsType
<
elem_op_vec_t
>()[
i
];
},
...
...
@@ -194,8 +197,7 @@ struct ThreadwiseTensorSliceTransfer_v7r2
[
&
](
auto
iDst
)
->
auto
&
{
using
DstData
=
remove_cvref_t
<
tuple_element_t
<
iDst
.
value
,
DstDatas
>>
;
using
elem_op_vec_t
=
typename
vector_type
<
DstData
,
elem_op_vec_len
>::
type
;
using
elem_op_vec_t
=
typename
vector_type
<
DstData
,
elem_op_vec_len
>::
type
;
return
dst_vectors
(
iDst
).
template
AsType
<
elem_op_vec_t
>()(
i
);
},
...
...
@@ -211,42 +213,6 @@ struct ThreadwiseTensorSliceTransfer_v7r2
// ...)
unpack2
(
element_op_
,
dst_data_refs
,
src_data_refs
);
});
}
else
{
// apply pointwise function
static_for
<
0
,
SrcScalarPerVector
,
1
>
{}([
&
](
auto
i
)
{
// get reference to src data
const
auto
src_data_refs
=
generate_tie
(
// return type should be lvalue
[
&
](
auto
iSrc
)
->
const
auto
&
{
using
SrcData
=
remove_cvref_t
<
tuple_element_t
<
iSrc
.
value
,
SrcDatas
>>
;
return
src_vectors
[
iSrc
].
template
AsType
<
SrcData
>()[
i
];
},
Number
<
nSrc
>
{});
// get reference to dst data
auto
dst_data_refs
=
generate_tie
(
// return type should be lvalue
[
&
](
auto
iDst
)
->
auto
&
{
using
DstData
=
remove_cvref_t
<
tuple_element_t
<
iDst
.
value
,
DstDatas
>>
;
return
dst_vectors
(
iDst
).
template
AsType
<
DstData
>()(
i
);
},
Number
<
nDst
>
{});
// apply pointwise function
// pointwise function signature:
// element_op_(dst_data_refs[I0],
// dst_data_refs[I1],
// ...,
// src_data_refs[I0],
// src_data_refs[I1],
// ...)
unpack2
(
element_op_
,
dst_data_refs
,
src_data_refs
);
});
}
dst_vectors_tuple_
(
iAccess
)
=
dst_vectors
;
...
...
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
View file @
0c823497
...
...
@@ -462,7 +462,6 @@ struct mfma_type<MfmaInstr::mfma_f64_16x16x4f64>
}
};
#if defined CK_ENABLE_FP8
template
<
>
struct
mfma_type
<
MfmaInstr
::
mfma_f32_32x32x16f8f8
>
{
...
...
@@ -506,9 +505,7 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32f8f8>
intrin_mfma_f32_16x16x32f8f8
<
MPerXdlops
,
NPerXdlops
>::
Run
(
a
,
b
,
reg_c
);
}
};
#endif
#if defined CK_ENABLE_BF8
template
<
>
struct
mfma_type
<
MfmaInstr
::
mfma_f32_32x32x16bf8bf8
>
{
...
...
@@ -552,9 +549,7 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32bf8bf8>
intrin_mfma_f32_16x16x32bf8bf8
<
MPerXdlops
,
NPerXdlops
>::
Run
(
a
,
b
,
reg_c
);
}
};
#endif
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
template
<
>
struct
mfma_type
<
MfmaInstr
::
mfma_f32_32x32x16f8bf8
>
{
...
...
@@ -598,9 +593,7 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32f8bf8>
intrin_mfma_f32_16x16x32f8bf8
<
MPerXdlops
,
NPerXdlops
>::
Run
(
a
,
b
,
reg_c
);
}
};
#endif
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
template
<
>
struct
mfma_type
<
MfmaInstr
::
mfma_f32_32x32x16bf8f8
>
{
...
...
@@ -644,7 +637,6 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32bf8f8>
intrin_mfma_f32_16x16x32bf8f8
<
MPerXdlops
,
NPerXdlops
>::
Run
(
a
,
b
,
reg_c
);
}
};
#endif
template
<
typename
base_type
,
index_t
MPerXdlops
,
...
...
@@ -792,7 +784,6 @@ struct MfmaSelector
}
#endif
#if defined CK_ENABLE_FP8
template
<
>
static
constexpr
auto
GetMfma
<
f8_t
,
32
,
32
>
()
{
...
...
@@ -804,9 +795,7 @@ struct MfmaSelector
{
return
MfmaInstr
::
mfma_f32_16x16x32f8f8
;
}
#endif
#if defined CK_ENABLE_BF8
template
<
>
static
constexpr
auto
GetMfma
<
bf8_t
,
32
,
32
>
()
{
...
...
@@ -818,9 +807,7 @@ struct MfmaSelector
{
return
MfmaInstr
::
mfma_f32_16x16x32bf8bf8
;
}
#endif
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
template
<
>
static
constexpr
auto
GetMfma
<
f8_t
,
32
,
32
,
bf8_t
>
()
{
...
...
@@ -832,9 +819,7 @@ struct MfmaSelector
{
return
MfmaInstr
::
mfma_f32_16x16x32f8bf8
;
}
#endif
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
template
<
>
static
constexpr
auto
GetMfma
<
bf8_t
,
32
,
32
,
f8_t
>
()
{
...
...
@@ -846,7 +831,6 @@ struct MfmaSelector
{
return
MfmaInstr
::
mfma_f32_16x16x32bf8f8
;
}
#endif
static
constexpr
auto
selected_mfma
=
mfma_type
<
GetMfma
<
base_type
,
MPerXdlops
,
NPerXdlops
,
additional_type
>
()
>
{};
...
...
@@ -1051,18 +1035,10 @@ struct XdlopsGemm
static_assert
(
is_same
<
base_type
,
double
>::
value
||
is_same
<
base_type
,
float
>::
value
||
is_same
<
base_type
,
half_t
>::
value
||
is_same
<
base_type
,
bhalf_t
>::
value
||
is_same
<
base_type
,
int8_t
>::
value
#if defined CK_ENABLE_FP8
||
is_same
<
base_type
,
f8_t
>::
value
#endif
#if defined CK_ENABLE_BF8
||
is_same
<
base_type
,
bf8_t
>::
value
#endif
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
||
(
is_same
<
base_type
,
f8_t
>::
value
&&
is_same
<
additional_type
,
bf8_t
>::
value
)
||
(
is_same
<
base_type
,
bf8_t
>::
value
&&
is_same
<
additional_type
,
f8_t
>::
value
)
#endif
,
is_same
<
base_type
,
int8_t
>::
value
||
is_same
<
base_type
,
f8_t
>::
value
||
is_same
<
base_type
,
bf8_t
>::
value
||
(
is_same
<
base_type
,
f8_t
>::
value
&&
is_same
<
additional_type
,
bf8_t
>::
value
)
||
(
is_same
<
base_type
,
bf8_t
>::
value
&&
is_same
<
additional_type
,
f8_t
>::
value
),
"base base_type must be double, float, half, bfloat16, int8_t, f8_t or bf8_t!"
);
static_for
<
0
,
KPack
/
mfma_instr
.
k_per_blk
,
1
>
{}([
&
](
auto
k
)
{
...
...
include/ck/utility/amd_buffer_addressing.hpp
View file @
0c823497
...
...
@@ -299,368 +299,146 @@ enum struct AmdBufferCoherenceEnum
GLC_SLC
=
3
,
};
template
<
typename
T
,
index_t
N
,
AmdBufferCoherenceEnum
coherence
=
AmdBufferCoherenceEnum
::
DefaultCoherence
>
__device__
typename
vector_type
<
T
,
N
>::
type
amd_buffer_load_impl
(
int32x4_t
src_wave_buffer_resource
,
template
<
index_t
N
,
AmdBufferCoherenceEnum
coherence
=
AmdBufferCoherenceEnum
::
DefaultCoherence
>
__device__
typename
vector_type
<
int8_t
,
N
>::
type
amd_buffer_load_impl_raw
(
int32x4_t
src_wave_buffer_resource
,
index_t
src_thread_addr_offset
,
index_t
src_wave_addr_offset
)
{
static_assert
(
(
is_same
<
T
,
double
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
))
||
(
is_same
<
T
,
float
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
is_same
<
T
,
half_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
is_same
<
T
,
bhalf_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
is_same
<
T
,
int32_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
is_same
<
T
,
int8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
)),
static_assert
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
||
N
==
32
||
N
==
64
,
"wrong! not implemented"
);
if
constexpr
(
is_same
<
T
,
double
>::
value
)
{
// use fp32 load to mimic fp64 load
if
constexpr
(
N
==
1
)
{
const
float2_t
tmp
=
llvm_amdgcn_raw_buffer_load_fp32x2
(
src_wave_buffer_resource
,
return
llvm_amdgcn_raw_buffer_load_i8
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
return
bit_cast
<
double
>
(
tmp
);
}
else
if
constexpr
(
N
==
2
)
{
const
float4_t
tmp
=
llvm_amdgcn_raw_buffer_load_fp32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
return
bit_cast
<
double2_t
>
(
tmp
);
}
else
if
constexpr
(
N
==
4
)
{
const
float4_t
f32_0
=
llvm_amdgcn_raw_buffer_load_fp32x4
(
src_wave_buffer_resource
,
int16_t
tmp
=
llvm_amdgcn_raw_buffer_load_i16
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
const
float4_t
f32_1
=
llvm_amdgcn_raw_buffer_load_fp32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
+
4
*
sizeof
(
float
),
static_cast
<
index_t
>
(
coherence
));
vector_type
<
double
,
4
>
tmp
;
tmp
.
AsType
<
double2_t
>
()(
Number
<
0
>
{})
=
bit_cast
<
double2_t
>
(
f32_0
);
tmp
.
AsType
<
double2_t
>
()(
Number
<
1
>
{})
=
bit_cast
<
double2_t
>
(
f32_1
);
return
tmp
.
AsType
<
double4_t
>
()(
Number
<
0
>
{});
}
}
else
if
constexpr
(
is_same
<
T
,
float
>::
value
)
{
if
constexpr
(
N
==
1
)
{
return
llvm_amdgcn_raw_buffer_load_fp32
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
2
)
{
return
llvm_amdgcn_raw_buffer_load_fp32x2
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
return
bit_cast
<
int8x2_t
>
(
tmp
);
}
else
if
constexpr
(
N
==
4
)
{
return
llvm_amdgcn_raw_buffer_load_fp32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
8
)
{
vector_type
<
float
,
8
>
tmp
;
tmp
.
AsType
<
float4_t
>
()(
Number
<
0
>
{})
=
llvm_amdgcn_raw_buffer_load_fp32x4
(
src_wave_buffer_resource
,
int32_t
tmp
=
llvm_amdgcn_raw_buffer_load_i32
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
tmp
.
AsType
<
float4_t
>
()(
Number
<
1
>
{})
=
llvm_amdgcn_raw_buffer_load_fp32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
+
4
*
sizeof
(
float
),
static_cast
<
index_t
>
(
coherence
));
return
tmp
.
AsType
<
float8_t
>
()(
Number
<
0
>
{});
}
}
else
if
constexpr
(
is_same
<
T
,
half_t
>::
value
)
{
if
constexpr
(
N
==
1
)
{
return
llvm_amdgcn_raw_buffer_load_fp16
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
2
)
{
return
llvm_amdgcn_raw_buffer_load_fp16x2
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
4
)
{
return
llvm_amdgcn_raw_buffer_load_fp16x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
return
bit_cast
<
int8x4_t
>
(
tmp
);
}
else
if
constexpr
(
N
==
8
)
{
// use fp32 load to mimic fp16 load
float4_t
tmp
=
llvm_amdgcn_raw_buffer_load_fp32x4
(
src_wave_buffer_resource
,
int32x2_t
tmp
=
llvm_amdgcn_raw_buffer_load_i32x2
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
return
bit_cast
<
half8_t
>
(
tmp
);
}
}
else
if
constexpr
(
is_same
<
T
,
bhalf_t
>::
value
)
{
if
constexpr
(
N
==
1
)
{
return
llvm_amdgcn_raw_buffer_load_i16
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
2
)
{
return
llvm_amdgcn_raw_buffer_load_i16x2
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
4
)
{
return
llvm_amdgcn_raw_buffer_load_i16x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
return
bit_cast
<
int8x8_t
>
(
tmp
);
}
else
if
constexpr
(
N
==
8
)
else
if
constexpr
(
N
==
16
)
{
int32x4_t
tmp
=
llvm_amdgcn_raw_buffer_load_i32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
return
bit_cast
<
bhalf8_t
>
(
tmp
);
}
}
else
if
constexpr
(
is_same
<
T
,
int32_t
>::
value
)
{
if
constexpr
(
N
==
1
)
{
return
llvm_amdgcn_raw_buffer_load_i32
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
2
)
{
return
llvm_amdgcn_raw_buffer_load_i32x2
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
4
)
{
return
llvm_amdgcn_raw_buffer_load_i32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
return
bit_cast
<
int8x16_t
>
(
tmp
);
}
else
if
constexpr
(
N
==
8
)
else
if
constexpr
(
N
==
32
)
{
vector_type
<
int32_t
,
8
>
tmp
;
tmp
.
AsType
<
int32x4_t
>
()(
Number
<
0
>
{})
=
llvm_amdgcn_raw_buffer_load_i32x4
(
src_wave_buffer_resource
,
int32x4_t
tmp0
=
llvm_amdgcn_raw_buffer_load_i32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
tmp
.
AsType
<
int32x4_t
>
()(
Number
<
1
>
{})
=
int32x4_t
tmp1
=
llvm_amdgcn_raw_buffer_load_i32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
+
4
*
sizeof
(
int32_t
),
static_cast
<
index_t
>
(
coherence
));
return
tmp
.
AsType
<
int32x8_t
>
()(
Number
<
0
>
{});
}
}
else
if
constexpr
(
is_same
<
T
,
int8_t
>::
value
)
{
if
constexpr
(
N
==
1
)
{
return
llvm_amdgcn_raw_buffer_load_i8
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
2
)
{
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
return
llvm_amdgcn_raw_buffer_load_i8x2
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
#else
int16_t
tmp
=
llvm_amdgcn_raw_buffer_load_i16
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
vector_type
<
int32_t
,
8
>
tmp
;
return
bit_cast
<
int8x2_t
>
(
tmp
);
#endif
}
else
if
constexpr
(
N
==
4
)
{
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
return
llvm_amdgcn_raw_buffer_load_i8x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
#else
int32_t
tmp
=
llvm_amdgcn_raw_buffer_load_i32
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
tmp
.
AsType
<
int32x4_t
>
()(
Number
<
0
>
{})
=
tmp0
;
tmp
.
AsType
<
int32x4_t
>
()(
Number
<
1
>
{})
=
tmp1
;
return
bit_cast
<
int8x4_t
>
(
tmp
);
#endif
return
bit_cast
<
int8x32_t
>
(
tmp
);
}
else
if
constexpr
(
N
==
8
)
else
if
constexpr
(
N
==
64
)
{
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
vector_type
<
int8_t
,
8
>
tmp
;
tmp
.
AsType
<
int8x4_t
>
()(
Number
<
0
>
{})
=
llvm_amdgcn_raw_buffer_load_i8x4
(
src_wave_buffer_resource
,
int32x4_t
tmp0
=
llvm_amdgcn_raw_buffer_load_i32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
tmp
.
AsType
<
int8x4_t
>
()(
Number
<
1
>
{})
=
llvm_amdgcn_raw_buffer_load_i8x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
+
4
*
sizeof
(
int8_t
),
static_cast
<
index_t
>
(
coherence
));
return
tmp
.
AsType
<
int8x8_t
>
()(
Number
<
0
>
{});
#else
int32x2_t
tmp
=
llvm_amdgcn_raw_buffer_load_i32x2
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
return
bit_cast
<
int8x8_t
>
(
tmp
);
#endif
}
else
if
constexpr
(
N
==
16
)
{
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
vector_type
<
int8_t
,
16
>
tmp
;
tmp
.
AsType
<
int8x4_t
>
()(
Number
<
0
>
{})
=
llvm_amdgcn_raw_buffer_load_i8x4
(
src_wave_buffer_resource
,
int32x4_t
tmp1
=
llvm_amdgcn_raw_buffer_load_i32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
src_wave_addr_offset
+
4
*
sizeof
(
int32_t
)
,
static_cast
<
index_t
>
(
coherence
));
tmp
.
AsType
<
int8x4_t
>
()(
Number
<
1
>
{})
=
llvm_amdgcn_raw_buffer_load_i8x4
(
src_wave_buffer_resource
,
int32x4_t
tmp2
=
llvm_amdgcn_raw_buffer_load_i32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
+
4
*
sizeof
(
int
8
_t
),
src_wave_addr_offset
+
8
*
sizeof
(
int
32
_t
),
static_cast
<
index_t
>
(
coherence
));
tmp
.
AsType
<
int8x4_t
>
()(
Number
<
2
>
{})
=
llvm_amdgcn_raw_buffer_load_i8x4
(
src_wave_buffer_resource
,
int32x4_t
tmp3
=
llvm_amdgcn_raw_buffer_load_i32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
+
8
*
sizeof
(
int
8
_t
),
src_wave_addr_offset
+
12
*
sizeof
(
int
32
_t
),
static_cast
<
index_t
>
(
coherence
));
tmp
.
AsType
<
int8x4_t
>
()(
Number
<
3
>
{})
=
llvm_amdgcn_raw_buffer_load_i8x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
+
12
*
sizeof
(
int8_t
),
static_cast
<
index_t
>
(
coherence
));
vector_type
<
int32_t
,
16
>
tmp
;
return
tmp
.
AsType
<
int8x16_t
>
()(
Number
<
0
>
{});
#else
int32x4_t
tmp
=
llvm_amdgcn_raw_buffer_load_i32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
tmp
.
AsType
<
int32x4_t
>
()(
Number
<
0
>
{})
=
tmp0
;
tmp
.
AsType
<
int32x4_t
>
()(
Number
<
1
>
{})
=
tmp1
;
tmp
.
AsType
<
int32x4_t
>
()(
Number
<
2
>
{})
=
tmp2
;
tmp
.
AsType
<
int32x4_t
>
()(
Number
<
3
>
{})
=
tmp3
;
return
bit_cast
<
int8x16_t
>
(
tmp
);
#endif
}
return
bit_cast
<
int8x64_t
>
(
tmp
);
}
}
template
<
typename
T
,
index_t
N
,
AmdBufferCoherenceEnum
coherence
=
AmdBufferCoherenceEnum
::
DefaultCoherence
>
__device__
void
amd_buffer_store_impl
(
const
typename
vector_type
<
T
,
N
>::
type
src_thread_data
,
__device__
typename
vector_type
<
T
,
N
>::
type
amd_buffer_load_impl
(
int32x4_t
src_wave_buffer_resource
,
index_t
src_thread_addr_offset
,
index_t
src_wave_addr_offset
)
{
static_assert
(
(
is_same
<
T
,
double
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
is_same
<
T
,
float
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
half_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
bhalf_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
int32_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
f8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
bf8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
int8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
)),
"wrong! not implemented"
);
using
r_t
=
typename
vector_type
<
T
,
N
>::
type
;
auto
raw_data
=
amd_buffer_load_impl_raw
<
sizeof
(
T
)
*
N
,
coherence
>
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
);
return
bit_cast
<
r_t
>
(
raw_data
);
}
template
<
index_t
N
,
AmdBufferCoherenceEnum
coherence
=
AmdBufferCoherenceEnum
::
DefaultCoherence
>
__device__
void
amd_buffer_store_impl_raw
(
const
typename
vector_type
<
int8_t
,
N
>::
type
src_thread_data
,
int32x4_t
dst_wave_buffer_resource
,
index_t
dst_thread_addr_offset
,
index_t
dst_wave_addr_offset
)
{
static_assert
(
(
is_same
<
T
,
double
>::
value
&&
(
N
==
1
||
N
==
2
))
||
(
is_same
<
T
,
float
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
is_same
<
T
,
half_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
is_same
<
T
,
bhalf_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
is_same
<
T
,
int32_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
))
||
(
is_same
<
T
,
int8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
)),
static_assert
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
||
N
==
32
||
N
==
64
,
"wrong! not implemented"
);
if
constexpr
(
is_same
<
T
,
double
>::
value
)
{
// use fp32 store to mimic fp64 store
if
constexpr
(
N
==
1
)
{
llvm_amdgcn_raw_buffer_store_fp32x2
(
bit_cast
<
float2_t
>
(
src_thread_data
),
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
2
)
{
llvm_amdgcn_raw_buffer_store_fp32x4
(
bit_cast
<
float4_t
>
(
src_thread_data
),
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
}
else
if
constexpr
(
is_same
<
T
,
float
>::
value
)
{
if
constexpr
(
N
==
1
)
{
llvm_amdgcn_raw_buffer_store_
fp32
(
src_thread_data
,
llvm_amdgcn_raw_buffer_store_
i8
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
...
...
@@ -668,7 +446,8 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
}
else
if
constexpr
(
N
==
2
)
{
llvm_amdgcn_raw_buffer_store_fp32x2
(
src_thread_data
,
llvm_amdgcn_raw_buffer_store_i16
(
bit_cast
<
int16_t
>
(
src_thread_data
),
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
...
...
@@ -676,7 +455,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
}
else
if
constexpr
(
N
==
4
)
{
llvm_amdgcn_raw_buffer_store_
fp32x4
(
src_thread_data
,
llvm_amdgcn_raw_buffer_store_
i32
(
bit_cast
<
int32_t
>
(
src_thread_data
)
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
...
...
@@ -684,199 +463,91 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
}
else
if
constexpr
(
N
==
8
)
{
vector_type
<
float
,
8
>
tmp
{
src_thread_data
};
llvm_amdgcn_raw_buffer_store_fp32x4
(
tmp
.
AsType
<
float4_t
>
()[
Number
<
0
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
llvm_amdgcn_raw_buffer_store_fp32x4
(
tmp
.
AsType
<
float4_t
>
()[
Number
<
1
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
+
4
*
sizeof
(
float
),
static_cast
<
index_t
>
(
coherence
));
}
}
else
if
constexpr
(
is_same
<
T
,
half_t
>::
value
)
{
if
constexpr
(
N
==
1
)
{
llvm_amdgcn_raw_buffer_store_fp16
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
2
)
{
llvm_amdgcn_raw_buffer_store_fp16x2
(
src_thread_data
,
llvm_amdgcn_raw_buffer_store_i32x2
(
bit_cast
<
int32x2_t
>
(
src_thread_data
),
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
4
)
else
if
constexpr
(
N
==
16
)
{
llvm_amdgcn_raw_buffer_store_
fp16x4
(
src_thread_data
,
llvm_amdgcn_raw_buffer_store_
i32x4
(
bit_cast
<
int32x4_t
>
(
src_thread_data
)
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
8
)
else
if
constexpr
(
N
==
32
)
{
#if 0
vector_type<half_t, 8> tmp{src_thread_data};
vector_type
<
int32_t
,
8
>
tmp
{
bit_cast
<
int32x8_t
>
(
src_thread_data
)};
llvm_amdgcn_raw_buffer_store_
fp16
x4(tmp.AsType<
half
4_t>()[Number<0>{}],
llvm_amdgcn_raw_buffer_store_
i32
x4
(
tmp
.
template
AsType
<
int32x
4_t
>()[
Number
<
0
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
llvm_amdgcn_raw_buffer_store_
fp16
x4(tmp.AsType<
half
4_t>()[Number<1>{}],
llvm_amdgcn_raw_buffer_store_
i32
x4
(
tmp
.
template
AsType
<
int32x
4_t
>()[
Number
<
1
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset +
4 *
sizeof(
half_t)
,
dst_wave_addr_offset
+
sizeof
(
int32_t
)
*
4
,
static_cast
<
index_t
>
(
coherence
));
#else
llvm_amdgcn_raw_buffer_store_fp32x4
(
bit_cast
<
float4_t
>
(
src_thread_data
),
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
#endif
}
}
else
if
constexpr
(
is_same
<
T
,
bhalf_t
>::
value
)
else
if
constexpr
(
N
==
64
)
{
if
constexpr
(
N
==
1
)
{
llvm_amdgcn_raw_buffer_store_i16
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
2
)
{
llvm_amdgcn_raw_buffer_store_i16x2
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
4
)
{
llvm_amdgcn_raw_buffer_store_i16x4
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
8
)
{
vector_type
<
bhalf_t
,
8
>
tmp
{
src_thread_data
};
vector_type
<
int32_t
,
16
>
tmp
{
bit_cast
<
int32x16_t
>
(
src_thread_data
)};
llvm_amdgcn_raw_buffer_store_i
16
x4
(
tmp
.
AsType
<
bhalf
4_t
>
()[
Number
<
0
>
{}],
llvm_amdgcn_raw_buffer_store_i
32
x4
(
tmp
.
template
AsType
<
int32x
4_t
>()[
Number
<
0
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
llvm_amdgcn_raw_buffer_store_i16x4
(
tmp
.
AsType
<
bhalf4_t
>
()[
Number
<
1
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
+
4
*
sizeof
(
bhalf_t
),
static_cast
<
index_t
>
(
coherence
));
}
}
else
if
constexpr
(
is_same
<
T
,
int32_t
>::
value
)
{
if
constexpr
(
N
==
1
)
{
llvm_amdgcn_raw_buffer_store_i32
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
2
)
{
llvm_amdgcn_raw_buffer_store_i32x2
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
4
)
{
llvm_amdgcn_raw_buffer_store_i32x4
(
src_thread_data
,
llvm_amdgcn_raw_buffer_store_i32x4
(
tmp
.
template
AsType
<
int32x4_t
>()[
Number
<
1
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
dst_wave_addr_offset
+
sizeof
(
int32_t
)
*
4
,
static_cast
<
index_t
>
(
coherence
));
}
}
else
if
constexpr
(
is_same
<
T
,
int8_t
>::
value
)
{
if
constexpr
(
N
==
1
)
{
llvm_amdgcn_raw_buffer_store_i8
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
2
)
{
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
llvm_amdgcn_raw_buffer_store_i8x2
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
#else
llvm_amdgcn_raw_buffer_store_i16
(
bit_cast
<
int16_t
>
(
src_thread_data
),
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
#endif
}
else
if
constexpr
(
N
==
4
)
{
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
llvm_amdgcn_raw_buffer_store_i8x4
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
#else
llvm_amdgcn_raw_buffer_store_i32
(
bit_cast
<
int32_t
>
(
src_thread_data
),
llvm_amdgcn_raw_buffer_store_i32x4
(
tmp
.
template
AsType
<
int32x4_t
>()[
Number
<
2
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
dst_wave_addr_offset
+
sizeof
(
int32_t
)
*
8
,
static_cast
<
index_t
>
(
coherence
));
#endif
}
else
if
constexpr
(
N
==
8
)
{
llvm_amdgcn_raw_buffer_store_i32x2
(
bit_cast
<
int32x2_t
>
(
src_thread_data
),
llvm_amdgcn_raw_buffer_store_i32x4
(
tmp
.
template
AsType
<
int32x4_t
>()[
Number
<
3
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
dst_wave_addr_offset
+
sizeof
(
int32_t
)
*
12
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
16
)
{
llvm_amdgcn_raw_buffer_store_i32x4
(
bit_cast
<
int32x4_t
>
(
src_thread_data
),
}
template
<
typename
T
,
index_t
N
,
AmdBufferCoherenceEnum
coherence
=
AmdBufferCoherenceEnum
::
DefaultCoherence
>
__device__
void
amd_buffer_store_impl
(
const
typename
vector_type
<
T
,
N
>::
type
src_thread_data
,
int32x4_t
dst_wave_buffer_resource
,
index_t
dst_thread_addr_offset
,
index_t
dst_wave_addr_offset
)
{
static_assert
(
(
is_same
<
T
,
double
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
is_same
<
T
,
float
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
half_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
bhalf_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
int32_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
f8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
bf8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
int8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
)),
"wrong! not implemented"
);
using
r_t
=
typename
vector_type
<
int8_t
,
sizeof
(
T
)
*
N
>::
type
;
amd_buffer_store_impl_raw
<
sizeof
(
T
)
*
N
,
coherence
>
(
bit_cast
<
r_t
>
(
src_thread_data
),
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
}
dst_wave_addr_offset
);
}
template
<
typename
T
,
index_t
N
>
...
...
@@ -1127,54 +798,14 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t
src_addr_shift
=
src_thread_element_valid
?
0
:
0x80000000
;
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
if
constexpr
(
is_same
<
scalar_t
,
f8_t
>::
value
||
is_same
<
scalar_t
,
bf8_t
>::
value
)
#endif
#if defined CK_ENABLE_FP8 && !defined CK_ENABLE_BF8
if
constexpr
(
is_same
<
scalar_t
,
f8_t
>::
value
)
#endif
#if !defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
if
constexpr
(
is_same
<
scalar_t
,
bf8_t
>::
value
)
#endif
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
{
auto
tmp
=
amd_buffer_load_impl
<
int8_t
,
vector_size
,
coherence
>
(
src_wave_buffer_resource
,
src_addr_shift
+
src_thread_addr_offset
,
0
);
return
bit_cast
<
vector_t
>
(
tmp
);
}
else
{
#endif
return
amd_buffer_load_impl
<
scalar_t
,
vector_size
,
coherence
>
(
src_wave_buffer_resource
,
src_addr_shift
+
src_thread_addr_offset
,
0
);
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
}
#endif
#else
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
if
constexpr
(
is_same
<
scalar_t
,
f8_t
>::
value
||
is_same
<
scalar_t
,
bf8_t
>::
value
)
#endif
#if defined CK_ENABLE_FP8 && !defined CK_ENABLE_BF8
if
constexpr
(
is_same
<
scalar_t
,
f8_t
>::
value
)
#endif
#if !defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
if
constexpr
(
is_same
<
scalar_t
,
bf8_t
>::
value
)
#endif
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
{
auto
tmp
=
amd_buffer_load_impl
<
int8_t
,
vector_size
,
coherence
>
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
0
);
return
src_thread_element_valid
?
bit_cast
<
vector_t
>
(
tmp
)
:
vector_t
(
0
);
}
else
{
#endif
vector_t
tmp
=
amd_buffer_load_impl
<
scalar_t
,
vector_size
,
coherence
>
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
0
);
return
src_thread_element_valid
?
tmp
:
vector_t
(
0
);
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
}
#endif
#endif
}
...
...
@@ -1232,62 +863,13 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
uint32_t
dst_addr_shift
=
dst_thread_element_valid
?
0
:
0x80000000
;
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
if
constexpr
(
is_same
<
scalar_t
,
f8_t
>::
value
||
is_same
<
scalar_t
,
bf8_t
>::
value
)
#endif
#if defined CK_ENABLE_FP8 && !defined CK_ENABLE_BF8
if
constexpr
(
is_same
<
scalar_t
,
f8_t
>::
value
)
#endif
#if !defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
if
constexpr
(
is_same
<
scalar_t
,
bf8_t
>::
value
)
#endif
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
{
auto
tmp
=
bit_cast
<
typename
vector_type_maker
<
int8_t
,
vector_size
>::
type
::
type
>
(
src_thread_data
);
amd_buffer_store_impl
<
int8_t
,
vector_size
,
coherence
>
(
tmp
,
dst_wave_buffer_resource
,
dst_addr_shift
+
dst_thread_addr_offset
,
0
);
}
else
{
#endif
amd_buffer_store_impl
<
scalar_t
,
vector_size
,
coherence
>
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_addr_shift
+
dst_thread_addr_offset
,
0
);
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
}
#endif
amd_buffer_store_impl
<
scalar_t
,
vector_size
,
coherence
>
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_addr_shift
+
dst_thread_addr_offset
,
0
);
#else
if
(
dst_thread_element_valid
)
{
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
if
constexpr
(
is_same
<
scalar_t
,
f8_t
>::
value
||
is_same
<
scalar_t
,
bf8_t
>::
value
)
#endif
#if defined CK_ENABLE_FP8 && !defined CK_ENABLE_BF8
if
constexpr
(
is_same
<
scalar_t
,
f8_t
>::
value
)
#endif
#if !defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
if
constexpr
(
is_same
<
scalar_t
,
bf8_t
>::
value
)
#endif
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
{
auto
tmp
=
bit_cast
<
typename
vector_type_maker
<
int8_t
,
vector_size
>::
type
::
type
>
(
src_thread_data
);
amd_buffer_store_impl
<
int8_t
,
vector_size
,
coherence
>
(
tmp
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
0
);
}
else
{
#endif
amd_buffer_store_impl
<
scalar_t
,
vector_size
,
coherence
>
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
0
);
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
}
#endif
}
#endif
}
...
...
include/ck/utility/amd_xdlops.hpp
View file @
0c823497
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_AMD_XDLOPS_HPP
#define CK_AMD_XDLOPS_HPP
#include "data_type.hpp"
#pragma once
namespace
ck
{
...
...
@@ -355,7 +352,6 @@ struct intrin_mfma_f64_16x16x4f64<16, 16>
}
};
#if defined CK_ENABLE_FP8
template
<
index_t
MPerWave
,
index_t
NPerWave
>
struct
intrin_mfma_f32_32x32x16f8f8
;
...
...
@@ -418,9 +414,7 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16>
#endif
}
};
#endif
#if defined CK_ENABLE_BF8
template
<
index_t
MPerWave
,
index_t
NPerWave
>
struct
intrin_mfma_f32_32x32x16bf8bf8
;
...
...
@@ -483,9 +477,7 @@ struct intrin_mfma_f32_16x16x32bf8bf8<16, 16>
#endif
}
};
#endif
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
template
<
index_t
MPerWave
,
index_t
NPerWave
>
struct
intrin_mfma_f32_32x32x16f8bf8
;
...
...
@@ -548,9 +540,7 @@ struct intrin_mfma_f32_16x16x32f8bf8<16, 16>
#endif
}
};
#endif
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
template
<
index_t
MPerWave
,
index_t
NPerWave
>
struct
intrin_mfma_f32_32x32x16bf8f8
;
...
...
@@ -613,6 +603,5 @@ struct intrin_mfma_f32_16x16x32bf8f8<16, 16>
#endif
}
};
#endif
}
// namespace ck
#endif
include/ck/utility/data_type.hpp
View file @
0c823497
...
...
@@ -9,15 +9,9 @@ namespace ck {
using
bhalf_t
=
ushort
;
using
half_t
=
_Float16
;
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
using
int4_t
=
_BitInt
(
4
);
#endif
#if defined CK_ENABLE_FP8
using
f8_t
=
_BitInt
(
8
);
#endif
#if defined CK_ENABLE_BF8
using
bf8_t
=
unsigned
_BitInt
(
8
);
#endif
// vector_type
template
<
typename
T
,
index_t
N
>
...
...
@@ -148,23 +142,19 @@ struct scalar_type<int4_t>
};
#endif
#if defined CK_ENABLE_FP8
template
<
>
struct
scalar_type
<
f8_t
>
{
using
type
=
f8_t
;
static
constexpr
index_t
vector_size
=
1
;
};
#endif
#if defined CK_ENABLE_BF8
template
<
>
struct
scalar_type
<
bf8_t
>
{
using
type
=
bf8_t
;
static
constexpr
index_t
vector_size
=
1
;
};
#endif
template
<
typename
T
>
struct
vector_type
<
T
,
1
>
...
...
@@ -968,24 +958,20 @@ using int8x32_t = typename vector_type<int8_t, 32>::type;
using
int8x64_t
=
typename
vector_type
<
int8_t
,
64
>::
type
;
// f8
#if defined CK_ENABLE_FP8
using
f8x2_t
=
typename
vector_type
<
f8_t
,
2
>::
type
;
using
f8x4_t
=
typename
vector_type
<
f8_t
,
4
>::
type
;
using
f8x8_t
=
typename
vector_type
<
f8_t
,
8
>::
type
;
using
f8x16_t
=
typename
vector_type
<
f8_t
,
16
>::
type
;
using
f8x32_t
=
typename
vector_type
<
f8_t
,
32
>::
type
;
using
f8x64_t
=
typename
vector_type
<
f8_t
,
64
>::
type
;
#endif
// bf8
#if defined CK_ENABLE_BF8
using
bf8x2_t
=
typename
vector_type
<
bf8_t
,
2
>::
type
;
using
bf8x4_t
=
typename
vector_type
<
bf8_t
,
4
>::
type
;
using
bf8x8_t
=
typename
vector_type
<
bf8_t
,
8
>::
type
;
using
bf8x16_t
=
typename
vector_type
<
bf8_t
,
16
>::
type
;
using
bf8x32_t
=
typename
vector_type
<
bf8_t
,
32
>::
type
;
using
bf8x64_t
=
typename
vector_type
<
bf8_t
,
64
>::
type
;
#endif
template
<
typename
T
>
struct
NumericLimits
...
...
@@ -1033,7 +1019,6 @@ struct NumericLimits<int4_t>
};
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#if defined CK_ENABLE_FP8
template
<
>
struct
NumericLimits
<
f8_t
>
{
...
...
@@ -1056,9 +1041,7 @@ struct NumericLimits<f8_t>
__host__
__device__
static
constexpr
f8_t
QuietNaN
()
{
return
f8_t
(
binary_qnan
);
}
};
#endif
#if defined CK_ENABLE_BF8
template
<
>
struct
NumericLimits
<
bf8_t
>
{
...
...
@@ -1081,7 +1064,6 @@ struct NumericLimits<bf8_t>
__host__
__device__
static
constexpr
bf8_t
QuietNaN
()
{
return
bf8_t
(
binary_qnan
);
}
};
#endif
template
<
typename
T
>
struct
NumericUtils
...
...
@@ -1093,6 +1075,7 @@ struct NumericUtils<float>
{
static
constexpr
int
exp
=
8
;
static
constexpr
int
mant
=
23
;
static
constexpr
int
bias
=
127
;
static
constexpr
uint32_t
nan_mask
=
0x7F800000
;
static
constexpr
uint32_t
head_mask
=
0xFF800000
;
static
constexpr
uint32_t
mant_mask
=
0x7FFFFF
;
...
...
@@ -1109,6 +1092,7 @@ struct NumericUtils<half_t>
{
static
constexpr
int
exp
=
5
;
static
constexpr
int
mant
=
10
;
static
constexpr
int
bias
=
15
;
static
constexpr
uint16_t
nan_mask
=
0x7C00
;
static
constexpr
uint16_t
head_mask
=
0xFC00
;
static
constexpr
uint16_t
mant_mask
=
0x3FF
;
...
...
@@ -1120,22 +1104,21 @@ struct NumericUtils<half_t>
using
bitwise_type
=
uint16_t
;
};
#if defined CK_ENABLE_FP8
template
<
>
struct
NumericUtils
<
f8_t
>
{
static
constexpr
int
exp
=
4
;
static
constexpr
int
mant
=
3
;
static
constexpr
int
bias
=
8
;
// negative zero nan mode
// static constexpr int bias = 7; // ieee mode
};
#endif
#if defined CK_ENABLE_BF8
template
<
>
struct
NumericUtils
<
bf8_t
>
{
static
constexpr
int
exp
=
5
;
static
constexpr
int
mant
=
2
;
static
constexpr
int
bias
=
16
;
// negative zero nan mode
// static constexpr int bias = 15; // ieee mode
};
#endif
}
// namespace ck
include/ck/utility/f8_utils.hpp
View file @
0c823497
...
...
@@ -5,9 +5,6 @@
#include "ck/utility/data_type.hpp"
// these conversions are disabled if native conversions available
#if !defined(__gfx940__) && !defined(__gfx941__) && !defined(__gfx942__)
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
namespace
ck
{
// fp8 rounding modes
...
...
@@ -19,6 +16,9 @@ enum class f8_rounding_mode
stochastic
};
__host__
inline
int
clz
(
uint32_t
x
)
{
return
__builtin_clz
(
x
);
}
__device__
inline
int
clz
(
uint32_t
x
)
{
return
__clz
(
x
);
}
}
// namespace ck
namespace
ck
::
utils
{
...
...
@@ -36,7 +36,7 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
constexpr
int
in_exp
=
NumericUtils
<
X
>::
exp
;
constexpr
int
in_mant
=
NumericUtils
<
X
>::
mant
;
int
exponent
;
int
exponent
,
bias
;
uint32_t
head
,
mantissa
,
sign
;
// nan code is same for float and half
constexpr
Y
nan_code
=
0x80
;
...
...
@@ -51,12 +51,11 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
mantissa
=
x_bitwise
&
NumericUtils
<
X
>::
mant_mask
;
exponent
=
(
head
>>
in_mant
)
&
NumericUtils
<
X
>::
exp_mask
;
sign
=
head
>>
(
in_exp
+
in_mant
);
bias
=
NumericUtils
<
X
>::
bias
;
uint32_t
signed_inf
=
(
sign
<<
(
in_exp
+
in_mant
))
+
(((
1
<<
in_exp
)
-
1
)
<<
in_mant
);
uint32_t
drop_mask
=
(
1
<<
(
in_mant
-
out_mant
))
-
1
;
constexpr
int
max_exp
=
(
1
<<
out_exp
)
-
(
negative_zero_nan
?
1
:
2
);
constexpr
int
exp_low_cutoff
=
(
1
<<
(
in_exp
-
1
))
-
(
1
<<
(
out_exp
-
1
))
+
1
-
(
negative_zero_nan
?
1
:
0
);
if
constexpr
(
negative_zero_nan
)
{
...
...
@@ -69,56 +68,107 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
return
signed_inf
+
(
mantissa
!=
0
?
1
:
0
);
}
// if input is half and output is bf8
if
((
NumericUtils
<
X
>::
mant
==
10
)
&&
(
NumericUtils
<
Y
>::
mant
==
2
)
&&
negative_zero_nan
&&
exponent
==
0
)
{
exponent
+=
1
;
while
(
mantissa
<
(
1
<<
in_mant
))
{
mantissa
<<=
1
;
exponent
-=
1
;
}
mantissa
&=
~
(
1
<<
in_mant
);
}
// check if x is 0.0
if
(
x_bitwise
==
0
)
return
0
;
exponent
-=
exp_low_cutoff
-
1
;
if
(
exponent
<=
0
)
drop_mask
=
(
1
<<
(
in_mant
-
out_mant
+
1
-
exponent
))
-
1
;
mantissa
+=
1
<<
in_mant
;
// apply random number if needed
mantissa
+=
(
stoch
?
rng
:
mantissa
)
&
drop_mask
;
if
(
mantissa
>=
(
2
<<
in_mant
))
// First need to check if it is normal or denorm as there is a difference of implict 1
// Then need to adjust the exponent to align with the F8 exponent, in the meanwhile, shift
// The mantissa. Then for stochastic rounding, add rng to mantissa and truncate. And for
// RNE, no need to add rng. Then probably need to check whether there is carry and adjust
// exponent and mantissa again3
// For IEEE bias mode, the bias is 2^(k-1)-1 where k is the width of exponent bits
const
int
out_bias
=
(
1
<<
(
out_exp
-
1
))
-
1
+
(
negative_zero_nan
?
1
:
0
);
const
int
out_denormal_act_exponent
=
1
-
out_bias
;
// actual exponent of f8 denormal
// act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
// out_exponent is the converted f8 exponent with bias encoding
// exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
// the difference needs to be adjusted and mantissa shifted
int
act_exponent
,
out_exponent
,
exponent_diff
;
if
(
exponent
==
0
)
{
// fp32/fp16 is in denormal.
/* fp32 denormal is below 2^-127 so it is usually not a concern here, we mostly concern fp16
here. In this case, f8 is usually in denormal. But there could be exceptions. fp16 denormal has
exponent bias 15 while bf8 with NANOO has exponent bias 16. It means that there are some numbers in
fp16 denormal but they are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers
where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8 (NANOO) normal.
In this case, the fp16 mantissa should be shift left by 1 */
act_exponent
=
exponent
-
bias
+
1
;
exponent_diff
=
out_denormal_act_exponent
-
act_exponent
;
// actual exponent is exponent-bias+1 as it is denormal
}
else
{
// fp32/fp16 is normal with implicit 1
act_exponent
=
exponent
-
bias
;
if
(
act_exponent
<=
out_denormal_act_exponent
)
{
mantissa
>>=
1
;
exponent
++
;
/* This is the case where fp32/fp16 is normal but it is in f8 denormal range.
For example fp8 nanoo mode, denormal exponent is -7, but if the fp32/fp16
actual exponent is -7, it is actually larger due to the implict 1,
Therefore it needs to be adjust to -6 and mantissa shift right by 1.
So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */
exponent_diff
=
out_denormal_act_exponent
-
act_exponent
;
}
else
{
// both fp32/fp16 and f8 are in normal range
exponent_diff
=
0
;
// exponent_diff=0 does not mean there is no difference for this case,
// act_exponent could be larger. Just that it does not need shift mantissa
}
mantissa
+=
(
1
<<
in_mant
);
// Add the implicit 1 into mantissa
}
mantissa
>>=
(
in_mant
-
out_mant
);
// check negative exponent
if
(
exponent
<=
0
)
bool
midpoint
=
(
mantissa
&
((
1
<<
(
in_mant
-
out_mant
+
exponent_diff
))
-
1
))
==
(
1
<<
(
in_mant
-
out_mant
+
exponent_diff
-
1
));
/* This part is a bit tricky. The judgment of whether it is a tie needs to be done before we
shift right as shift right could rip off some residual part and make something not midpoint look
like midpoint. For example, the fp16 number 0x1002 (0 00100 0000000010), it is larger than
midpoint, but after shift right by 4 bits, it would look like midpoint. */
if
(
exponent_diff
>
0
)
mantissa
>>=
exponent_diff
;
else
if
(
exponent_diff
==
-
1
)
mantissa
<<=
-
exponent_diff
;
bool
implicit_one
=
mantissa
&
(
1
<<
in_mant
);
// if there is no implict 1, it means the f8 is denormal and need to adjust to denorm exponent
out_exponent
=
(
act_exponent
+
exponent_diff
)
/*actual f8 exponent*/
+
out_bias
-
(
implicit_one
?
0
:
1
);
// Now we have the exponent and mantissa adjusted
bool
odd
=
mantissa
&
(
1
<<
(
in_mant
-
out_mant
));
// if the least significant bit that is not truncated is 1
mantissa
+=
(
stoch
?
rng
:
(
midpoint
?
(
odd
?
mantissa
:
mantissa
-
1
)
:
mantissa
))
&
drop_mask
;
// Now we deal with overflow
if
(
out_exponent
==
0
)
{
if
(
x_bitwise
==
0
)
return
0
;
if
((
1
<<
in_mant
)
&
mantissa
)
{
out_exponent
=
1
;
// denormal overflow to become normal, promote exponent
// No need to make 1 implicit now as it will be addressed later
}
}
else
{
// subnormal range; represented by a subnormal float8 (exponent 0)
// and involves loss of accuracy
mantissa
>>=
1
-
exponent
;
exponent
=
0
;
if
((
1
<<
(
in_mant
+
1
))
&
mantissa
)
{
mantissa
>>=
1
;
out_exponent
++
;
// No need to make 1 implicit now as it will be addressed later
}
}
// above range: quantize to maximum possible float of the same sign
else
if
(
exponent
>
max_exp
)
mantissa
>>=
(
in_mant
-
out_mant
);
if
(
out_exponent
>
max_exp
)
{
if
(
clip
)
{
mantissa
=
(
1
<<
out_mant
)
-
1
;
exponent
=
max_exp
;
out_
exponent
=
max_exp
;
}
else
{
...
...
@@ -127,10 +177,10 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
}
// check if x is 0.0 or -0.0
if
(
exponent
==
0
&&
mantissa
==
0
)
if
(
out_
exponent
==
0
&&
mantissa
==
0
)
return
negative_zero_nan
?
0
:
(
sign
<<
(
out_exp
+
out_mant
));
mantissa
&=
(
1
<<
out_mant
)
-
1
;
return
(
sign
<<
(
out_exp
+
out_mant
))
|
(
exponent
<<
out_mant
)
|
mantissa
;
return
(
sign
<<
(
out_exp
+
out_mant
))
|
(
out_
exponent
<<
out_mant
)
|
mantissa
;
}
template
<
typename
X
,
typename
Y
,
bool
negative_zero_nan
>
...
...
@@ -196,12 +246,9 @@ __host__ __device__ Y run_cast_from_f8(X x)
if
(
exponent
==
0
)
{
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
exponent
++
;
while
(
mantissa
<
(
1
<<
in_mant
))
{
mantissa
<<=
1
;
exponent
--
;
}
int
sh
=
1
+
clz
(
mantissa
)
-
(
32
-
in_mant
);
mantissa
<<=
sh
;
exponent
+=
1
-
sh
;
mantissa
&=
((
1
<<
in_mant
)
-
1
);
}
exponent
+=
exp_low_cutoff
-
1
;
...
...
@@ -244,5 +291,3 @@ __host__ __device__ Y cast_from_f8(X x)
}
}
// namespace ck::utils
#endif // #if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
#endif // #if !defined(__gfx940__) && !defined(__gfx941__) && !defined(__gfx942__)
include/ck/utility/inner_product.hpp
View file @
0c823497
...
...
@@ -192,6 +192,8 @@ inner_product<int8x4_t, int8x4_t, int32_t>(const int8x4_t& a, const int8x4_t& b,
#else
c
=
__builtin_amdgcn_sdot4
(
bit_cast
<
int32_t
>
(
a
),
bit_cast
<
int32_t
>
(
b
),
c
,
false
);
#endif
#elif defined(CK_USE_AMD_V_DOT4_I32_I8_GFX11)
c
=
__builtin_amdgcn_sudot4
(
true
,
bit_cast
<
int32_t
>
(
a
),
true
,
bit_cast
<
int32_t
>
(
b
),
c
,
false
);
#else
const
vector_type
<
int8_t
,
4
>
a_vector
{
a
};
const
vector_type
<
int8_t
,
4
>
b_vector
{
b
};
...
...
include/ck/utility/is_detected.hpp
View file @
0c823497
...
...
@@ -31,4 +31,13 @@ struct nonesuch
template
<
template
<
class
...
>
class
Op
,
class
...
Args
>
using
is_detected
=
typename
detail
::
detector
<
nonesuch
,
void
,
Op
,
Args
...
>::
value_t
;
template
<
typename
T
>
using
is_pack2_invocable_t
=
decltype
(
std
::
declval
<
T
&>
().
is_pack2_invocable
);
template
<
typename
T
>
using
is_pack4_invocable_t
=
decltype
(
std
::
declval
<
T
&>
().
is_pack4_invocable
);
template
<
typename
T
>
using
is_pack8_invocable_t
=
decltype
(
std
::
declval
<
T
&>
().
is_pack8_invocable
);
}
// namespace ck
Prev
1
…
4
5
6
7
8
9
10
11
12
…
21
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment