gaoqiong / composable_kernel_ROCM / Commits

Commit a4522ae3, authored Nov 06, 2024 by illsilin

    sync from public repo

Parents: 1f127242, e0594d08
Changes: 425
Showing 20 changed files with 2097 additions and 394 deletions (+2097, -394)
Changed files:
- example/ck_tile/12_smoothquant/script/smoke_test.sh (+30, -0)
- example/ck_tile/12_smoothquant/smoothquant.cpp (+218, -0)
- example/ck_tile/12_smoothquant/smoothquant.hpp (+114, -0)
- example/ck_tile/CMakeLists.txt (+6, -0)
- include/ck/host_utility/flush_cache.hpp (+38, -17)
- include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp (+2, -2)
- include/ck/tensor_operation/gpu/device/device_cgemm.hpp (+3, -3)
- include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp (+17, -1)
- include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp (+6, -6)
- include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp (+12, -12)
- include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp (+6, -6)
- include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp (+16, -3)
- include/ck/tensor_operation/gpu/element/element_wise_operation.hpp (+24, -4)
- include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp (+801, -150)
- include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp (+12, -12)
- include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp (+4, -4)
- include/ck/utility/amd_xdlops.hpp (+6, -6)
- include/ck/utility/data_type.hpp (+767, -163)
- include/ck/utility/math_v2.hpp (+9, -5)
- include/ck_tile/core.hpp (+6, -0)
example/ck_tile/12_smoothquant/script/smoke_test.sh (new file, mode 100755)

#!/bin/sh
EXE="$(find . -name tile_smoothquant -type f | head -n 1)"

for pr_i in "fp16" "bf16" ; do
$EXE -prec=$pr_i -m=99 -n=13
$EXE -prec=$pr_i -m=17 -n=16
$EXE -prec=$pr_i -m=1 -n=100
$EXE -prec=$pr_i -m=4 -n=128
$EXE -prec=$pr_i -m=80 -n=127
$EXE -prec=$pr_i -m=22 -n=255 -stride=256
$EXE -prec=$pr_i -m=7 -n=599
$EXE -prec=$pr_i -m=19 -n=512
$EXE -prec=$pr_i -m=33 -n=313 -stride=1000
$EXE -prec=$pr_i -m=11 -n=510
$EXE -prec=$pr_i -m=171 -n=676 -stride=818
$EXE -prec=$pr_i -m=91 -n=636
$EXE -prec=$pr_i -m=12 -n=768 -stride=800
$EXE -prec=$pr_i -m=100 -n=766 -stride=812
$EXE -prec=$pr_i -m=31 -n=1024
$EXE -prec=$pr_i -m=64 -n=1000 -stride=1004
$EXE -prec=$pr_i -m=8 -n=1501
$EXE -prec=$pr_i -m=3 -n=1826
$EXE -prec=$pr_i -m=5 -n=2040
$EXE -prec=$pr_i -m=7 -n=2734
$EXE -prec=$pr_i -m=1 -n=3182
$EXE -prec=$pr_i -m=9 -n=4096
$EXE -prec=$pr_i -m=3 -n=8192
$EXE -prec=$pr_i -m=1 -n=10547
$EXE -prec=$pr_i -m=3 -n=17134
done
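Each invocation of tile_smoothquant prints a single summary line built by the std::cout calls in smoothquant.cpp below; a hypothetical line for the first case, with placeholders instead of measured timing values, would look like:

[fp16] m:99, n:13, stride:13, <time> us, <bandwidth> GB/s, valid:y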
example/ck_tile/12_smoothquant/smoothquant.cpp (new file, mode 100644)

#include "ck_tile/host.hpp"
#include "smoothquant.hpp"
#include <cstring>

// different threshold for different dtype
template <typename DataType>
auto get_elimit()
{
    double rtol = 1e-5;
    double atol = 1e-5;
    return ck_tile::make_tuple(rtol, atol);
}

template <>
auto get_elimit<ck_tile::bf16_t>()
{
    double rtol = 1e-5;
    double atol = 1e-5;
    return ck_tile::make_tuple(rtol, atol);
}

template <>
auto get_elimit<ck_tile::int8_t>()
{
    // due to rounding, int8 quantization might have 1 abs error
    double rtol = 1;
    double atol = 1;
    return ck_tile::make_tuple(rtol, atol);
}

auto create_args(int argc, char* argv[])
{
    ck_tile::ArgParser arg_parser;
    arg_parser.insert("m", "3328", "m dimension")
        .insert("n", "4096", "n dimension")
        .insert("stride", "-1", "stride per row, if -1 then equal to n")
        .insert("v", "1", "cpu validation or not")
        .insert("kname", "1", "print kernel name or not")
        .insert("prec", "fp16", "precision")
        .insert("warmup", "5", "cold iter")
        .insert("repeat", "20", "hot iter");

    bool result = arg_parser.parse(argc, argv);
    return std::make_tuple(result, arg_parser);
}

template <typename DataType>
bool run(const ck_tile::ArgParser& arg_parser)
{
    ck_tile::index_t m      = arg_parser.get_int("m");
    ck_tile::index_t n      = arg_parser.get_int("n");
    ck_tile::index_t stride = arg_parser.get_int("stride");
    if(stride < 0)
        stride = n;
    std::string data_type = arg_parser.get_str("prec");
    int kname             = arg_parser.get_int("kname");
    int do_validation     = arg_parser.get_int("v");
    int warmup            = arg_parser.get_int("warmup");
    int repeat            = arg_parser.get_int("repeat");

    assert(stride >= n);

    using TypeConfig = SmoothquantTypeConfig<DataType>;

    using XDataType       = typename TypeConfig::XDataType;
    using XScaleDataType  = typename TypeConfig::XScaleDataType;
    using YScaleDataType  = typename TypeConfig::YScaleDataType;
    using QYDataType      = typename TypeConfig::QYDataType;
    using ComputeDataType = typename TypeConfig::ComputeDataType;

    // host verify
    ck_tile::HostTensor<XDataType> x_host({m, n}, {stride, 1});
    ck_tile::HostTensor<XScaleDataType> xscale_host({n});
    ck_tile::HostTensor<YScaleDataType> yscale_host_ref({m}, {1});
    ck_tile::HostTensor<YScaleDataType> yscale_host_dev({m}, {1});
    ck_tile::HostTensor<QYDataType> qy_host_ref({m, n}, {stride, 1});
    ck_tile::HostTensor<QYDataType> qy_host_dev({m, n}, {stride, 1});

    ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
    ck_tile::FillUniformDistribution<XScaleDataType>{1e-3, .5f}(xscale_host);

    ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem xscale_buf(xscale_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem yscale_buf(yscale_host_dev.get_element_space_size_in_bytes());
    ck_tile::DeviceMem qy_buf(qy_host_dev.get_element_space_size_in_bytes());

    x_buf.ToDevice(x_host.data());
    xscale_buf.ToDevice(xscale_host.data());

    std::cout << "[" << data_type << "]"
              << " m:" << m << ", n:" << n << ", stride:" << stride << std::flush;

    smoothquant_traits traits{data_type};

    smoothquant_args args{x_buf.GetDeviceBuffer(),
                          xscale_buf.GetDeviceBuffer(),
                          yscale_buf.GetDeviceBuffer(),
                          qy_buf.GetDeviceBuffer(),
                          m,
                          n,
                          stride};

    float ave_time = smoothquant(
        traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});

    std::size_t num_byte = sizeof(XDataType) * m * n + sizeof(XScaleDataType) * n +
                           sizeof(YScaleDataType) * m + sizeof(QYDataType) * m * n;

    float gb_per_sec = num_byte / 1.E6 / ave_time;
    std::cout << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::flush;

    bool pass = true;

    if(do_validation)
    {
        using YDataType = ComputeDataType;
        ck_tile::HostTensor<ComputeDataType> y_host({m, n}, {stride, 1});

        // smooth outlier
        {
            auto f = [&](auto n_) {
                auto v_xscale = ck_tile::type_convert<ComputeDataType>(xscale_host(n_));

                for(int m_ = 0; m_ < m; ++m_)
                {
                    auto v_x       = ck_tile::type_convert<ComputeDataType>(x_host(m_, n_));
                    y_host(m_, n_) = v_x * v_xscale;
                }
            };

            ck_tile::make_ParallelTensorFunctor(f, xscale_host.get_element_space_size())(
                std::thread::hardware_concurrency());
        }

        // yscale
        {
            ck_tile::HostTensor<YDataType> y_rowwise_amax_host({m});

            using ReduceAmax = ck_tile::ReduceOp::AbsMax;
            ck_tile::reference_reduce<ComputeDataType, ComputeDataType, YDataType>(
                y_host, y_rowwise_amax_host, ReduceAmax{});

            auto op = [](const auto& v0) {
                return v0 /
                       ck_tile::type_convert<ComputeDataType>(ck_tile::numeric<QYDataType>::max());
            };
            ck_tile::reference_unary_elementwise<YDataType, YScaleDataType, ComputeDataType>(
                y_rowwise_amax_host, yscale_host_ref, op);

            yscale_buf.FromDevice(yscale_host_dev.mData.data());

            auto [rtol, atol] = get_elimit<YScaleDataType>();
            pass &= ck_tile::check_err(yscale_host_dev,
                                       yscale_host_ref,
                                       std::string("yscale Error: Incorrect results!"),
                                       rtol,
                                       atol);
        }

        // rowwise quantization
        {
            ck_tile::reference_rowwise_quantization2d<YDataType, YScaleDataType, QYDataType>(
                y_host, yscale_host_ref, qy_host_ref);

            qy_buf.FromDevice(qy_host_dev.data());
            auto [rtol, atol] = get_elimit<QYDataType>();

            if(stride == n)
            {
                pass = ck_tile::check_err(qy_host_dev,
                                          qy_host_ref,
                                          std::string("qy Error: Incorrect results!"),
                                          rtol,
                                          atol);
            }
            else
            {
                for(int i_r = 0; i_r < m; i_r++)
                {
                    std::vector<QYDataType> qy_host_dev_row(qy_host_dev.begin() + i_r * stride,
                                                            qy_host_dev.begin() + i_r * stride + n);
                    std::vector<QYDataType> qy_host_ref_row(qy_host_ref.begin() + i_r * stride,
                                                            qy_host_ref.begin() + i_r * stride + n);
                    pass &= ck_tile::check_err(qy_host_dev_row,
                                               qy_host_ref_row,
                                               std::string("qy[") + std::to_string(i_r) +
                                                   std::string("] Error: Incorrect results!"),
                                               rtol,
                                               atol);
                }
            }
        }

        std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
    }

    return pass;
}

int main(int argc, char* argv[])
{
    auto [result, arg_parser] = create_args(argc, argv);
    if(!result)
        return -1;

    const std::string data_type = arg_parser.get_str("prec");
    if(data_type == "fp16")
    {
        return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
    }
    else if(data_type == "bf16")
    {
        return run<ck_tile::bf16_t>(arg_parser) ? 0 : -2;
    }
    return -3;
}
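For orientation, the host-side validation above boils down to three steps: scale each column of x by xscale ("smooth outlier"), take the row-wise absolute maximum of the result and divide it by the int8 maximum to get a per-row yscale, then quantize each row by its yscale. The following is a minimal standalone sketch of that reference math, not the ck_tile implementation; it assumes plain row-major std::vector storage and a round-to-nearest conversion, both of which are illustrative.

// Minimal sketch of the SmoothQuant reference math used for validation above.
// Container layout and rounding mode are illustrative assumptions.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

void smoothquant_reference(const std::vector<float>& x,      // [m * n], row-major
                           const std::vector<float>& xscale, // [n]
                           std::vector<float>& yscale,       // [m]
                           std::vector<int8_t>& qy,          // [m * n]
                           int m, int n)
{
    for(int i = 0; i < m; ++i)
    {
        // smooth outlier: y = x * xscale (per column)
        std::vector<float> y(n);
        float amax = 0.f;
        for(int j = 0; j < n; ++j)
        {
            y[j] = x[i * n + j] * xscale[j];
            amax = std::max(amax, std::abs(y[j]));
        }
        // per-row scale: row-wise absmax divided by int8 max
        yscale[i] = amax / 127.f;
        // row-wise quantization
        for(int j = 0; j < n; ++j)
            qy[i * n + j] = static_cast<int8_t>(std::round(y[j] / yscale[i]));
    }
}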
example/ck_tile/12_smoothquant/smoothquant.hpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/smoothquant.hpp"
#include <string>

template <typename DataType>
struct SmoothquantTypeConfig;

template <>
struct SmoothquantTypeConfig<ck_tile::half_t>
{
    using XDataType       = ck_tile::half_t;
    using XScaleDataType  = float;
    using YScaleDataType  = float;
    using QYDataType      = ck_tile::int8_t;
    using ComputeDataType = float;
};

template <>
struct SmoothquantTypeConfig<ck_tile::bf16_t>
{
    using XDataType       = ck_tile::bf16_t;
    using XScaleDataType  = float;
    using YScaleDataType  = float;
    using QYDataType      = ck_tile::int8_t;
    using ComputeDataType = float;
};

// runtime args
struct smoothquant_args : public ck_tile::SmoothquantHostArgs
{
};

// this is used to pattern-match internal kernel implementation, not to instantiate kernel
template <typename DataType_,
          ck_tile::index_t Repeat_M_,         // each thread repeat along M
          ck_tile::index_t Repeat_N_,         // each thread repeat along N
          ck_tile::index_t ThreadPerBlock_M_, // num threads along M
          ck_tile::index_t ThreadPerBlock_N_, // num threads along N
          ck_tile::index_t Vector_N_,         // vector size along N
          bool kPadN_,
          bool kTwoPass_>
struct smoothquant_traits_
{
    using DataType = ck_tile::remove_cvref_t<DataType_>;

    static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize;
    static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0);
    static constexpr ck_tile::index_t total_warps =
        (ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize;

    // num of warps along m
    static constexpr ck_tile::index_t BlockWarps_M = []() {
        if constexpr(is_warp_per_row)
        {
            static_assert(warpSize % ThreadPerBlock_N_ == 0);
            return total_warps * (warpSize / ThreadPerBlock_N_);
        }
        else
        {
            // static_assert(warpSize % ThreadPerBlock_M_ == 0);
            return total_warps / (ThreadPerBlock_N_ / warpSize);
        }
    }();

    // num of warps along n
    static constexpr ck_tile::index_t BlockWarps_N = []() {
        if constexpr(is_warp_per_row)
        {
            static_assert(warpSize % ThreadPerBlock_N_ == 0);
            return 1;
        }
        else
        {
            static_assert(ThreadPerBlock_N_ % warpSize == 0);
            return ThreadPerBlock_N_ / warpSize;
        }
    }();

    static constexpr ck_tile::index_t Repeat_M = Repeat_M_;
    static constexpr ck_tile::index_t Repeat_N = Repeat_N_;

    static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_;
    static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_;

    static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M;
    static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_;

    using BlockTile  = ck_tile::sequence<Block_M, Block_N>;
    using BlockWarps = ck_tile::sequence<BlockWarps_M, BlockWarps_N>;
    using WarpTile   = ck_tile::sequence<Warp_M, Warp_N>;
    using Vector     = ck_tile::sequence<1, Vector_N_>;

    using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;

    static constexpr bool kPadN    = kPadN_;
    static constexpr bool kTwoPass = kTwoPass_;
};

template <typename Traits_>
float smoothquant_(const ck_tile::stream_config& s, smoothquant_args a);

// This is the public API, will be generated by script
struct smoothquant_traits
{
    std::string data_type;
};

float smoothquant(smoothquant_traits, smoothquant_args, const ck_tile::stream_config&);
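To make the tile-shape derivation in smoothquant_traits_ concrete, here is a hypothetical instantiation; the parameter values are illustrative and not taken from the generated kernel list, and it assumes warpSize is 64 as on the AMD GPUs this library targets. With ThreadPerBlock_N_ equal to warpSize, is_warp_per_row holds, so all warps stack along M and the block covers Repeat_N_ * ThreadPerBlock_N_ * Vector_N_ columns.

// Hypothetical instantiation of smoothquant_traits_ (values are illustrative).
// Assumes warpSize == 64.
using ExampleTrait = smoothquant_traits_<ck_tile::half_t,
                                         /*Repeat_M_*/ 1,
                                         /*Repeat_N_*/ 1,
                                         /*ThreadPerBlock_M_*/ 4,
                                         /*ThreadPerBlock_N_*/ 64,
                                         /*Vector_N_*/ 8,
                                         /*kPadN_*/ true,
                                         /*kTwoPass_*/ false>;

// 4x64 threads = 256 threads = 4 warps, all stacked along M
// (is_warp_per_row holds since ThreadPerBlock_N_ <= warpSize).
static_assert(ExampleTrait::Block_M == 4);   // Repeat_M_ * ThreadPerBlock_M_
static_assert(ExampleTrait::Block_N == 512); // Repeat_N_ * ThreadPerBlock_N_ * Vector_N_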
example/ck_tile/CMakeLists.txt

@@ -6,3 +6,9 @@ add_subdirectory(01_fmha)
 add_subdirectory(02_layernorm2d)
 add_subdirectory(03_gemm)
 add_subdirectory(04_img2col)
+add_subdirectory(05_reduce)
+add_subdirectory(06_permute)
+add_subdirectory(09_topk_softmax)
+add_subdirectory(10_rmsnorm2d)
+add_subdirectory(11_add_rmsnorm2d_rdquant)
+add_subdirectory(12_smoothquant)
include/ck/host_utility/flush_cache.hpp

@@ -237,7 +237,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
                                              Args... args)
 {
 #if CK_TIME_KERNEL
-#define MEDIAN 1
+#define MEDIAN 0
     if(stream_config.time_kernel_)
     {
         if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))

@@ -275,6 +275,14 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
 #else
         float total_time = 0;
 #endif
+        hipEvent_t start, stop;
+
+        hip_check_error(hipEventCreate(&start));
+        hip_check_error(hipEventCreate(&stop));
+
+        hip_check_error(hipDeviceSynchronize());
+        hip_check_error(hipEventRecord(start, stream_config.stream_id_));
+
         for(int i = 0; i < nrepeat; ++i)
         {
             if constexpr(!TimePreprocess)

@@ -282,13 +290,13 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
                 preprocess();
             }
-            hipEvent_t start, stop;
+            // hipEvent_t start, stop;
-            hip_check_error(hipEventCreate(&start));
-            hip_check_error(hipEventCreate(&stop));
+            // hip_check_error(hipEventCreate(&start));
+            // hip_check_error(hipEventCreate(&stop));
-            hip_check_error(hipDeviceSynchronize());
-            hip_check_error(hipEventRecord(start, stream_config.stream_id_));
+            // hip_check_error(hipDeviceSynchronize());
+            // hip_check_error(hipEventRecord(start, stream_config.stream_id_));
             // calculate preprocess time
             if constexpr(TimePreprocess)
             {

@@ -299,25 +307,34 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
             hip_check_error(hipGetLastError());
             // end real kernel
-            hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
-            hip_check_error(hipEventSynchronize(stop));
-            float cur_time = 0;
-            hip_check_error(hipEventElapsedTime(&cur_time, start, stop));
-#if MEDIAN
-            times.insert(cur_time);
-#else
-            total_time += cur_time;
-#endif
+            // hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
+            // hip_check_error(hipEventSynchronize(stop));
+            // float cur_time = 0;
+            // hip_check_error(hipEventElapsedTime(&cur_time, start, stop));
+            // #if MEDIAN
+            // times.insert(cur_time);
+            // #else
+            // total_time += cur_time;
+            // #endif
             if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
             {
-                std::cout << "i: " << i << " cur_time: " << cur_time << std::endl;
+                // std::cout << "i: " << i << " cur_time: " << cur_time << std::endl;
+                printf("gemm_args.p_a_grid: %p, gemm_args.p_b_grid:%p\n",
+                       static_cast<const void*>(gemm_args.p_a_grid),
+                       static_cast<const void*>(gemm_args.p_b_grid));
             }
         }
+        hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
+        hip_check_error(hipEventSynchronize(stop));
+        float cur_time = 0;
+        hip_check_error(hipEventElapsedTime(&cur_time, start, stop));
+#if MEDIAN
+        times.insert(cur_time);
+#else
+        total_time += cur_time;
+#endif
 #if MEDIAN
         auto mid = times.begin();

@@ -333,7 +350,11 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
             return (*mid + *mid_next) / 2;
         }
 #else
-        return total_time / nrepeat;
+        // return total_time / nrepeat;
+        hipDeviceProp_t deviceProps;
+        hip_check_error(hipGetDeviceProperties(&deviceProps, 0));
+        float preprocess_offset = deviceProps.multiProcessorCount == 80 ? 0.005 : 0.01;
+        return (total_time - preprocess_offset * nrepeat) / nrepeat;
 #endif
     }
     else
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp

@@ -352,7 +352,7 @@ struct BlockwiseGemmWMMA
                     constexpr index_t c_offset =
                         c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
-                    wmma_gemm.template Run(
+                    wmma_gemm.template Run<>(
                         a_thread_vec.template AsType<wmma_input_type_a>(),
                         b_thread_vec.template AsType<wmma_input_type_b>(),
                         c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));

@@ -406,7 +406,7 @@ struct BlockwiseGemmWMMA
                     constexpr index_t c_offset =
                         c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
-                    wmma_gemm.template Run(
+                    wmma_gemm.template Run<>(
                         a_thread_vec.template AsType<wmma_input_type_a>(),
                         b_thread_vec.template AsType<wmma_input_type_b>(),
                         c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
include/ck/tensor_operation/gpu/device/device_cgemm.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

 #include "device_base.hpp"

@@ -31,13 +31,13 @@ struct DeviceCGemm : public BaseOperator
                             CElementwiseOperation c_element_op,
                             ck::index_t KBatch = 1) = 0;

     virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;

     virtual std::size_t GetWorkspaceSize(index_t MRaw,
                                          index_t NRaw,
                                          index_t KRaw,
                                          index_t StrideA,
                                          index_t StrideB,
-                                         index_t StrideC) = 0;
+                                         index_t StrideC) const = 0;
 };

 template <typename AElementwiseOperation,
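The interface change above makes GetWorkspaceSize a const pure-virtual, so every derived device op must add const to its override, as the 4-GEMM CShuffle implementation in the next diff does. A minimal sketch of that C++ pattern, using hypothetical class names rather than the actual composable_kernel types:

// Sketch of the const-qualification pattern (BaseOp/DerivedOp are hypothetical names).
#include <cstddef>

struct BaseOp
{
    // Pure-virtual query; const so it can be called through a const handle.
    virtual std::size_t GetWorkspaceSize(int M, int N, int /*K*/) const = 0;
    virtual ~BaseOp() = default;
};

struct DerivedOp final : BaseOp
{
    // The override must repeat the const qualifier; otherwise it declares a new,
    // non-overriding member and `override` fails to compile.
    std::size_t GetWorkspaceSize(int M, int N, int /*K*/) const override
    {
        return static_cast<std::size_t>(M) * N * sizeof(float); // illustrative size
    }
};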
include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp

@@ -598,10 +598,26 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
                                     [[maybe_unused]] index_t K,
                                     [[maybe_unused]] index_t StrideA,
                                     [[maybe_unused]] index_t StrideB,
-                                    index_t StrideC) override
+                                    index_t StrideC) const override
     {
         return 2 * sizeof(CDataType) * GetCElementSpaceSize(M, N, StrideC);
     }

+    std::size_t GetWorkSpaceSize(const BaseArgument* base_arg) const override
+    {
+        const auto* parg = dynamic_cast<const Argument*>(base_arg);
+        if(!parg)
+        {
+            std::ostringstream err;
+            err << "Provided argument pointer is not of an Argument class!"
+                << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
+            throw std::runtime_error(err.str());
+        }
+        return GetWorkspaceSize(
+            parg->M, parg->N, parg->K, parg->StrideA, parg->StrideB, parg->StrideC);
+    }
 };

 } // namespace device
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp

@@ -93,12 +93,12 @@ __global__ void
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
     const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);

-    const long_index_t a_batch_offset =
-        amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
-    const long_index_t b_batch_offset =
-        amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
-    const long_index_t e_batch_offset =
-        amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
+    const long_index_t a_batch_offset = amd_wave_read_first_lane(
+        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
+    const long_index_t b_batch_offset = amd_wave_read_first_lane(
+        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
+    const long_index_t e_batch_offset = amd_wave_read_first_lane(
+        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));

     const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp

@@ -60,12 +60,12 @@ __global__ void
     const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge);
     const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);

-    const long_index_t a_batch_offset =
-        amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
-    const long_index_t b_batch_offset =
-        amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
-    const long_index_t e_batch_offset =
-        amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
+    const long_index_t a_batch_offset = amd_wave_read_first_lane(
+        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
+    const long_index_t b_batch_offset = amd_wave_read_first_lane(
+        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
+    const long_index_t e_batch_offset = amd_wave_read_first_lane(
+        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));

     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

@@ -117,12 +117,12 @@ __global__ void
     const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge);
     const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);

-    const long_index_t a_batch_offset =
-        amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
-    const long_index_t b_batch_offset =
-        amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
-    const long_index_t e_batch_offset =
-        amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
+    const long_index_t a_batch_offset = amd_wave_read_first_lane(
+        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
+    const long_index_t b_batch_offset = amd_wave_read_first_lane(
+        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
+    const long_index_t e_batch_offset = amd_wave_read_first_lane(
+        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));

     // Pass two lds pointer is the key to tell compiler that ds_read/write
     // operate on different lds chunk at same time without order dependecy
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp

@@ -98,12 +98,12 @@ __global__ void
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
     const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);

-    const long_index_t a_batch_offset =
-        amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
-    const long_index_t b_batch_offset =
-        amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
-    const long_index_t c_batch_offset =
-        amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
+    const long_index_t a_batch_offset = amd_wave_read_first_lane(
+        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
+    const long_index_t b_batch_offset = amd_wave_read_first_lane(
+        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
+    const long_index_t c_batch_offset = amd_wave_read_first_lane(
+        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));

     const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp

@@ -85,9 +85,9 @@ __global__ void
             BsPointer p_bs_grid,
             DsPointer p_ds_grid,
             EDataType* __restrict__ p_e_grid,
-            const AElementwiseOperation a_element_op,
-            const BElementwiseOperation b_element_op,
-            const CDEElementwiseOperation cde_element_op,
+            AElementwiseOperation a_element_op,
+            BElementwiseOperation b_element_op,
+            CDEElementwiseOperation cde_element_op,
             const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1,
             const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1,
             const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock

@@ -121,6 +121,19 @@ __global__ void
         static_for<0, NumDTensor, 1>{}(
             [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_group_offset[i]; });

+        if constexpr(is_same_v<AElementwiseOperation, element_wise::DynamicUnaryOp>)
+        {
+            a_element_op.InitUnaryOpPtrOnDevice();
+        }
+        if constexpr(is_same_v<BElementwiseOperation, element_wise::DynamicUnaryOp>)
+        {
+            b_element_op.InitUnaryOpPtrOnDevice();
+        }
+        if constexpr(is_same_v<CDEElementwiseOperation, element_wise::DynamicUnaryOp>)
+        {
+            cde_element_op.InitUnaryOpPtrOnDevice();
+        }
+
         if constexpr(isMultiA || isMultiB)
         {
             AsPointer p_as_grid_grp;
include/ck/tensor_operation/gpu/element/element_wise_operation.hpp
View file @
a4522ae3
...
...
@@ -272,6 +272,26 @@ struct MultiplyMultiply
e
=
ck
::
type_convert
<
ck
::
bhalf_t
>
(
x0_f
);
}
template
<
>
__host__
__device__
constexpr
void
operator
()
<
ck
::
half_t
,
int
,
ck
::
half_t
,
ck
::
half_t
>
(
ck
::
half_t
&
e
,
const
int
&
c
,
const
ck
::
half_t
&
d0
,
const
ck
::
half_t
&
d1
)
const
{
const
float
x0_f
=
ck
::
type_convert
<
float
>
(
c
)
*
ck
::
type_convert
<
float
>
(
d0
)
*
ck
::
type_convert
<
float
>
(
d1
);
e
=
ck
::
type_convert
<
ck
::
half_t
>
(
x0_f
);
}
template
<
>
__host__
__device__
constexpr
void
operator
()
<
ck
::
bhalf_t
,
int
,
float
,
float
>
(
ck
::
bhalf_t
&
e
,
const
int
&
c
,
const
float
&
d0
,
const
float
&
d1
)
const
{
const
float
x0_f
=
ck
::
type_convert
<
float
>
(
c
)
*
ck
::
type_convert
<
float
>
(
d0
)
*
ck
::
type_convert
<
float
>
(
d1
);
e
=
ck
::
type_convert
<
ck
::
bhalf_t
>
(
x0_f
);
}
};
struct
MultiplyAddFastGelu
...
...
@@ -385,7 +405,7 @@ struct ScaleAddScaleAddRelu
const
float
&
d1
)
const
{
const
float
x
=
c
*
alpha1_
+
alpha2_
*
d0
+
d1
;
Relu
{}.
template
operator
()
<
float
>(
e
,
x
)
;
e
=
x
>
0
?
x
:
0
;
}
template
<
>
...
...
@@ -396,7 +416,7 @@ struct ScaleAddScaleAddRelu
type_convert
<
float
>
(
d1
);
float
result
=
0
;
Relu
{}.
template
operator
()
<
float
>(
result
,
x
)
;
result
=
x
>
0
?
x
:
0
;
e
=
type_convert
<
half_t
>
(
result
);
}
...
...
@@ -409,7 +429,7 @@ struct ScaleAddScaleAddRelu
type_convert
<
float
>
(
d1
);
float
result
=
0
;
Relu
{}.
template
operator
()
<
float
>(
result
,
x
)
;
result
=
x
>
0
?
x
:
0
;
e
=
type_convert
<
bhalf_t
>
(
result
);
}
...
...
@@ -421,7 +441,7 @@ struct ScaleAddScaleAddRelu
const
float
x
=
type_convert
<
float
>
(
c
)
*
alpha1_
+
alpha2_
*
d0
+
d1
;
float
result
=
0
;
Relu
{}.
template
operator
()
<
float
>(
result
,
x
)
;
result
=
x
>
0
?
x
:
0
;
e
=
type_convert
<
int8_t
>
(
result
);
}
...
...
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
View file @
a4522ae3
...
...
@@ -7,11 +7,38 @@
#include "ck/utility/math.hpp"
#include "ck/utility/math_v2.hpp"
#include "ck/utility/type_convert.hpp"
#include <cassert>
namespace
ck
{
namespace
tensor_operation
{
namespace
element_wise
{
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wnon-virtual-dtor"
struct
UnaryOpBase
{
public:
__host__
__device__
~
UnaryOpBase
()
=
default
;
__host__
__device__
constexpr
UnaryOpBase
()
=
default
;
__host__
__device__
constexpr
UnaryOpBase
(
const
UnaryOpBase
&
)
=
default
;
__host__
__device__
constexpr
UnaryOpBase
(
UnaryOpBase
&&
)
=
default
;
__host__
__device__
UnaryOpBase
&
operator
=
(
const
UnaryOpBase
&
)
=
default
;
__host__
__device__
UnaryOpBase
&
operator
=
(
UnaryOpBase
&&
)
=
default
;
__host__
__device__
virtual
inline
void
operator
()(
float
&
y
,
const
float
&
x
)
const
=
0
;
__host__
__device__
virtual
inline
void
operator
()(
double
&
y
,
const
double
&
x
)
const
=
0
;
__host__
__device__
virtual
inline
void
operator
()(
int32_t
&
y
,
const
int32_t
&
x
)
const
=
0
;
__host__
__device__
virtual
inline
void
operator
()(
int8_t
&
y
,
const
int8_t
&
x
)
const
=
0
;
__host__
__device__
virtual
inline
void
operator
()(
half_t
&
y
,
const
half_t
&
x
)
const
=
0
;
__host__
__device__
virtual
inline
void
operator
()(
bhalf_t
&
y
,
const
bhalf_t
&
x
)
const
=
0
;
};
struct
PassThroughPack2
{
template
<
typename
Y
,
typename
X
>
...
...
@@ -25,17 +52,30 @@ struct PassThroughPack2
constexpr
const
static
bool
is_pack2_invocable
=
true
;
};
struct
PassThrough
struct
PassThrough
final
:
public
UnaryOpBase
{
__host__
__device__
constexpr
PassThrough
()
=
default
;
__host__
__device__
constexpr
PassThrough
(
const
PassThrough
&
)
=
default
;
__host__
__device__
constexpr
PassThrough
(
PassThrough
&&
)
=
default
;
__host__
__device__
PassThrough
&
operator
=
(
const
PassThrough
&
)
=
default
;
__host__
__device__
PassThrough
&
operator
=
(
PassThrough
&&
)
=
default
;
__host__
__device__
~
PassThrough
()
=
default
;
__host__
__device__
inline
void
operator
()(
float
&
y
,
const
float
&
x
)
const
final
{
y
=
x
;
}
__host__
__device__
inline
void
operator
()(
double
&
y
,
const
double
&
x
)
const
final
{
y
=
x
;
}
__host__
__device__
inline
void
operator
()(
int32_t
&
y
,
const
int32_t
&
x
)
const
final
{
y
=
x
;
}
__host__
__device__
inline
void
operator
()(
int8_t
&
y
,
const
int8_t
&
x
)
const
final
{
y
=
x
;
}
__host__
__device__
inline
void
operator
()(
half_t
&
y
,
const
half_t
&
x
)
const
final
{
y
=
x
;
}
__host__
__device__
inline
void
operator
()(
bhalf_t
&
y
,
const
bhalf_t
&
x
)
const
final
{
y
=
x
;
}
template
<
typename
Y
,
typename
X
>
__host__
__device__
void
operator
()(
Y
&
y
,
const
X
&
x
)
const
;
template
<
>
__host__
__device__
void
operator
()
<
double
,
double
>
(
double
&
y
,
const
double
&
x
)
const
{
y
=
x
;
}
template
<
>
__host__
__device__
void
operator
()
<
float
,
double
>
(
float
&
y
,
const
double
&
x
)
const
{
...
...
@@ -48,36 +88,12 @@ struct PassThrough
y
=
type_convert
<
double
>
(
x
);
}
template
<
>
__host__
__device__
void
operator
()
<
float
,
float
>
(
float
&
y
,
const
float
&
x
)
const
{
y
=
x
;
}
template
<
>
__host__
__device__
void
operator
()
<
half_t
,
half_t
>
(
half_t
&
y
,
const
half_t
&
x
)
const
{
y
=
x
;
}
template
<
>
__host__
__device__
void
operator
()
<
half_t
,
float
>
(
half_t
&
y
,
const
float
&
x
)
const
{
y
=
type_convert
<
half_t
>
(
x
);
}
template
<
>
__host__
__device__
void
operator
()
<
bhalf_t
,
bhalf_t
>
(
bhalf_t
&
y
,
const
bhalf_t
&
x
)
const
{
y
=
x
;
}
template
<
>
__host__
__device__
void
operator
()
<
int32_t
,
int32_t
>
(
int32_t
&
y
,
const
int32_t
&
x
)
const
{
y
=
x
;
}
template
<
>
__host__
__device__
void
operator
()
<
bhalf_t
,
float
>
(
bhalf_t
&
y
,
const
float
&
x
)
const
{
...
...
@@ -102,12 +118,6 @@ struct PassThrough
y
=
type_convert
<
float
>
(
x
);
}
template
<
>
__host__
__device__
void
operator
()
<
int8_t
,
int8_t
>
(
int8_t
&
y
,
const
int8_t
&
x
)
const
{
y
=
x
;
}
template
<
>
__host__
__device__
void
operator
()
<
half_t
,
int8_t
>
(
half_t
&
y
,
const
int8_t
&
x
)
const
{
...
...
@@ -407,17 +417,48 @@ struct UnarySquare
};
};
struct
UnaryAbs
struct
UnaryAbs
final
:
public
UnaryOpBase
{
template
<
typename
T
>
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
__host__
__device__
constexpr
UnaryAbs
()
=
default
;
__host__
__device__
constexpr
UnaryAbs
(
const
UnaryAbs
&
)
=
default
;
__host__
__device__
constexpr
UnaryAbs
(
UnaryAbs
&&
)
=
default
;
__host__
__device__
UnaryAbs
&
operator
=
(
const
UnaryAbs
&
)
=
default
;
__host__
__device__
UnaryAbs
&
operator
=
(
UnaryAbs
&&
)
=
default
;
__host__
__device__
~
UnaryAbs
()
=
default
;
__host__
__device__
inline
void
operator
()(
float
&
y
,
const
float
&
x
)
const
final
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
"Data type is not supported by this operation!"
);
y
=
ck
::
math
::
abs
(
x
);
}
__host__
__device__
inline
void
operator
()(
double
&
y
,
const
double
&
x
)
const
final
{
y
=
ck
::
math
::
abs
(
x
);
}
__host__
__device__
inline
void
operator
()(
int32_t
&
y
,
const
int32_t
&
x
)
const
final
{
y
=
ck
::
math
::
abs
(
x
);
}
__host__
__device__
inline
void
operator
()(
int8_t
&
y
,
const
int8_t
&
x
)
const
final
{
y
=
ck
::
math
::
abs
(
x
);
}
__host__
__device__
inline
void
operator
()(
half_t
&
y
,
const
half_t
&
x
)
const
final
{
y
=
ck
::
math
::
abs
(
x
);
}
__host__
__device__
inline
void
operator
()(
bhalf_t
&
y
,
const
bhalf_t
&
x
)
const
final
{
y
=
ck
::
math
::
abs
(
x
);
}
__host__
__device__
void
operator
()(
f8_t
&
y
,
const
f8_t
&
x
)
const
{
y
=
ck
::
type_convert
<
f8_t
>
(
ck
::
math
::
abs
(
ck
::
type_convert
<
float
>
(
x
)));
};
};
...
...
@@ -433,20 +474,41 @@ struct UnarySqrt
};
};
struct
Relu
struct
Relu
final
:
public
UnaryOpBase
{
template
<
typename
T
>
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
__host__
__device__
constexpr
Relu
()
=
default
;
__host__
__device__
constexpr
Relu
(
const
Relu
&
)
=
default
;
__host__
__device__
constexpr
Relu
(
Relu
&&
)
=
default
;
__host__
__device__
Relu
&
operator
=
(
const
Relu
&
)
=
default
;
__host__
__device__
Relu
&
operator
=
(
Relu
&&
)
=
default
;
__host__
__device__
~
Relu
()
=
default
;
__host__
__device__
inline
void
operator
()(
float
&
y
,
const
float
&
x
)
const
final
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
"Data type is not supported by this operation!"
);
y
=
x
>
0
?
x
:
0
;
}
template
<
>
__host__
__device__
void
operator
()(
bhalf_t
&
y
,
const
bhalf_t
&
x
)
const
__host__
__device__
inline
void
operator
()(
double
&
y
,
const
double
&
x
)
const
final
{
y
=
x
>
0
?
x
:
0
;
}
__host__
__device__
inline
void
operator
()(
int32_t
&
y
,
const
int32_t
&
x
)
const
final
{
y
=
x
>
0
?
x
:
0
;
}
__host__
__device__
inline
void
operator
()(
int8_t
&
y
,
const
int8_t
&
x
)
const
final
{
y
=
x
>
0
?
x
:
0
;
}
__host__
__device__
inline
void
operator
()(
half_t
&
y
,
const
half_t
&
x
)
const
final
{
y
=
x
>
0
?
x
:
0
;
}
__host__
__device__
inline
void
operator
()(
bhalf_t
&
y
,
const
bhalf_t
&
x
)
const
final
{
float
x_f32
=
ck
::
type_convert
<
float
>
(
x
);
float
y_f32
=
x_f32
>
0
?
x_f32
:
0
;
...
...
@@ -593,18 +655,52 @@ struct Gelu
}
};
struct
Sigmoid
struct
Sigmoid
final
:
public
UnaryOpBase
{
template
<
typename
T
>
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
__host__
__device__
constexpr
Sigmoid
()
=
default
;
__host__
__device__
constexpr
Sigmoid
(
const
Sigmoid
&
)
=
default
;
__host__
__device__
constexpr
Sigmoid
(
Sigmoid
&&
)
=
default
;
__host__
__device__
Sigmoid
&
operator
=
(
const
Sigmoid
&
)
=
default
;
__host__
__device__
Sigmoid
&
operator
=
(
Sigmoid
&&
)
=
default
;
__host__
__device__
~
Sigmoid
()
=
default
;
__host__
__device__
inline
void
operator
()(
float
&
y
,
const
float
&
x
)
const
final
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
ck
::
half_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
,
"Data type is not supported by this operation!"
);
constexpr
T
one
=
type_convert
<
T
>
(
1
);
y
=
one
/
(
one
+
ck
::
math
::
exp
(
-
x
));
};
constexpr
float
one
=
type_convert
<
float
>
(
1
);
y
=
one
/
(
one
+
ck
::
math
::
exp
(
-
x
));
}
__host__
__device__
inline
void
operator
()(
double
&
y
,
const
double
&
x
)
const
final
{
constexpr
double
one
=
type_convert
<
double
>
(
1
);
y
=
one
/
(
one
+
ck
::
math
::
exp
(
-
x
));
}
__host__
__device__
inline
void
operator
()(
int32_t
&
y
,
const
int32_t
&
x
)
const
final
{
constexpr
int32_t
one
=
type_convert
<
int32_t
>
(
1
);
y
=
one
/
(
one
+
ck
::
math
::
exp
(
-
x
));
}
__host__
__device__
inline
void
operator
()(
int8_t
&
y
,
const
int8_t
&
x
)
const
final
{
constexpr
int8_t
one
=
type_convert
<
int8_t
>
(
1
);
y
=
one
/
(
one
+
ck
::
math
::
exp
(
-
x
));
}
__host__
__device__
inline
void
operator
()(
half_t
&
y
,
const
half_t
&
x
)
const
final
{
constexpr
half_t
one
=
type_convert
<
half_t
>
(
1
);
y
=
one
/
(
one
+
ck
::
math
::
exp
(
-
x
));
}
__host__
__device__
inline
void
operator
()(
bhalf_t
&
y
,
const
bhalf_t
&
x
)
const
final
{
constexpr
float
one
=
type_convert
<
float
>
(
1
);
float
x_f32
=
ck
::
type_convert
<
float
>
(
x
);
float
y_f32
=
one
/
(
one
+
ck
::
math
::
exp
(
x_f32
));
y
=
ck
::
type_convert
<
bhalf_t
>
(
y_f32
);
}
};
struct
Silu
...
...
@@ -620,18 +716,44 @@ struct Silu
};
};
struct
TanH
struct
TanH
final
:
public
UnaryOpBase
{
template
<
typename
T
>
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
__host__
__device__
constexpr
TanH
()
=
default
;
__host__
__device__
constexpr
TanH
(
const
TanH
&
)
=
default
;
__host__
__device__
constexpr
TanH
(
TanH
&&
)
=
default
;
__host__
__device__
TanH
&
operator
=
(
const
TanH
&
)
=
default
;
__host__
__device__
TanH
&
operator
=
(
TanH
&&
)
=
default
;
__host__
__device__
~
TanH
()
=
default
;
__host__
__device__
inline
void
operator
()(
float
&
y
,
const
float
&
x
)
const
final
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
ck
::
half_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
,
"Data type is not supported by this operation!"
);
y
=
ck
::
math
::
tanh
(
x
);
}
__host__
__device__
inline
void
operator
()(
double
&
y
,
const
double
&
x
)
const
final
{
y
=
ck
::
math
::
tanh
(
x
);
};
}
__host__
__device__
inline
void
operator
()(
int32_t
&
y
,
const
int32_t
&
x
)
const
final
{
y
=
ck
::
math
::
tanh
(
x
);
}
__host__
__device__
inline
void
operator
()(
int8_t
&
y
,
const
int8_t
&
x
)
const
final
{
y
=
ck
::
math
::
tanh
(
x
);
}
__host__
__device__
inline
void
operator
()(
half_t
&
y
,
const
half_t
&
x
)
const
final
{
y
=
ck
::
math
::
tanh
(
x
);
}
__host__
__device__
inline
void
operator
()(
bhalf_t
&
y
,
const
bhalf_t
&
x
)
const
final
{
y
=
ck
::
math
::
tanh
(
x
);
}
};
struct
ACos
...
...
@@ -872,138 +994,418 @@ struct Rcp
};
};
struct
Swish
struct
Swish
final
:
public
UnaryOpBase
{
Swish
(
float
beta
=
1.0
f
)
:
beta_
(
beta
)
{}
__host__
__device__
constexpr
Swish
(
const
Swish
&
)
=
default
;
__host__
__device__
constexpr
Swish
(
Swish
&&
)
=
default
;
__host__
__device__
~
Swish
()
=
default
;
__host__
__device__
Swish
(
float
beta
=
1.0
f
)
:
beta_
(
beta
)
{}
__host__
__device__
float
get_beta
()
const
{
return
beta_
;
}
const
float
beta_
;
__host__
__device__
inline
void
operator
()(
float
&
y
,
const
float
&
x
)
const
final
{
float
bx
=
-
beta_
*
type_convert
<
float
>
(
x
);
y
=
type_convert
<
float
>
(
x
/
(
1.
f
+
ck
::
math
::
exp
(
bx
)));
}
__host__
__device__
inline
void
operator
()(
double
&
y
,
const
double
&
x
)
const
final
{
float
bx
=
-
beta_
*
type_convert
<
float
>
(
x
);
y
=
type_convert
<
double
>
(
x
/
(
1.
f
+
ck
::
math
::
exp
(
bx
)));
}
__host__
__device__
inline
void
operator
()(
int32_t
&
y
,
const
int32_t
&
x
)
const
final
{
float
bx
=
-
beta_
*
type_convert
<
float
>
(
x
);
y
=
type_convert
<
int32_t
>
(
x
/
(
1.
f
+
ck
::
math
::
exp
(
bx
)));
}
__host__
__device__
inline
void
operator
()(
int8_t
&
y
,
const
int8_t
&
x
)
const
final
{
float
bx
=
-
beta_
*
type_convert
<
float
>
(
x
);
y
=
type_convert
<
int8_t
>
(
x
/
(
1.
f
+
ck
::
math
::
exp
(
bx
)));
}
__host__
__device__
inline
void
operator
()(
half_t
&
y
,
const
half_t
&
x
)
const
final
{
float
bx
=
-
beta_
*
type_convert
<
float
>
(
x
);
y
=
type_convert
<
half_t
>
(
x
/
(
1.
f
+
ck
::
math
::
exp
(
bx
)));
}
__host__
__device__
inline
void
operator
()(
bhalf_t
&
y
,
const
bhalf_t
&
x
)
const
final
{
float
bx
=
-
beta_
*
type_convert
<
float
>
(
x
);
y
=
type_convert
<
bhalf_t
>
(
x
/
(
1.
f
+
ck
::
math
::
exp
(
bx
)));
}
template
<
typename
Y
,
typename
X
>
__host__
__device__
void
operator
()(
Y
&
y
,
const
X
&
x
)
const
{
static_assert
(
is_same
<
X
,
float
>::
value
||
is_same
<
X
,
double
>::
value
||
is_same
<
X
,
ck
::
half_t
>::
value
,
is_same
<
X
,
half_t
>::
value
,
"Data type is not supported by this operation!"
);
static_assert
(
is_same
<
Y
,
float
>::
value
||
is_same
<
Y
,
double
>::
value
||
is_same
<
Y
,
ck
::
half_t
>::
value
,
is_same
<
Y
,
half_t
>::
value
,
"Data type is not supported by this operation!"
);
float
bx
=
-
beta_
*
type_convert
<
float
>
(
x
);
y
=
type_convert
<
Y
>
(
x
/
(
1.
f
+
ck
::
math
::
exp
(
bx
)));
};
const
float
beta_
;
}
};
struct
SoftRelu
struct
SoftRelu
final
:
public
UnaryOpBase
{
SoftRelu
(
float
alpha
=
1.
f
)
:
alpha_
(
alpha
){};
__host__
__device__
constexpr
SoftRelu
(
const
SoftRelu
&
)
=
default
;
__host__
__device__
constexpr
SoftRelu
(
SoftRelu
&&
)
=
default
;
__host__
__device__
~
SoftRelu
()
=
default
;
template
<
typename
T
>
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
__host__
__device__
SoftRelu
(
float
alpha
=
1.0
f
)
:
alpha_
(
alpha
)
{}
__host__
__device__
float
get_alpha
()
const
{
return
alpha_
;
}
const
float
alpha_
;
__host__
__device__
inline
void
operator
()(
float
&
y
,
const
float
&
x
)
const
final
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
"Data type is not supported by this operation!"
);
T
casted_alpha
=
type_convert
<
T
>
(
alpha_
);
constexpr
T
one
=
type_convert
<
T
>
(
1
);
y
=
ck
::
math
::
log
(
one
+
ck
::
math
::
exp
(
x
*
casted_alpha
))
/
casted_alpha
;
float
casted_alpha
=
type_convert
<
float
>
(
alpha_
);
constexpr
float
one
=
type_convert
<
float
>
(
1
);
y
=
ck
::
math
::
log
(
one
+
ck
::
math
::
exp
(
x
*
casted_alpha
))
/
casted_alpha
;
}
__host__
__device__
inline
void
operator
()(
double
&
y
,
const
double
&
x
)
const
final
{
double
casted_alpha
=
type_convert
<
double
>
(
alpha_
);
constexpr
double
one
=
type_convert
<
double
>
(
1
);
y
=
ck
::
math
::
log
(
one
+
ck
::
math
::
exp
(
x
*
casted_alpha
))
/
casted_alpha
;
}
__host__
__device__
inline
void
operator
()(
int32_t
&
y
,
const
int32_t
&
x
)
const
final
{
int32_t
casted_alpha
=
type_convert
<
int32_t
>
(
alpha_
);
constexpr
int32_t
one
=
type_convert
<
int32_t
>
(
1
);
y
=
ck
::
math
::
log
(
one
+
ck
::
math
::
exp
(
x
*
casted_alpha
))
/
casted_alpha
;
}
__host__
__device__
inline
void
operator
()(
int8_t
&
y
,
const
int8_t
&
x
)
const
final
{
int8_t
casted_alpha
=
type_convert
<
int8_t
>
(
alpha_
);
constexpr
int8_t
one
=
type_convert
<
int8_t
>
(
1
);
y
=
ck
::
math
::
log
(
one
+
ck
::
math
::
exp
(
x
*
casted_alpha
))
/
casted_alpha
;
}
__host__
__device__
inline
void
operator
()(
half_t
&
y
,
const
half_t
&
x
)
const
final
{
half_t
casted_alpha
=
type_convert
<
half_t
>
(
alpha_
);
constexpr
half_t
one
=
type_convert
<
half_t
>
(
1
);
y
=
ck
::
math
::
log
(
one
+
ck
::
math
::
exp
(
x
*
casted_alpha
))
/
casted_alpha
;
}
__host__
__device__
inline
void
operator
()(
bhalf_t
&
y
,
const
bhalf_t
&
x
)
const
final
{
bhalf_t
casted_alpha
=
type_convert
<
bhalf_t
>
(
alpha_
);
constexpr
bhalf_t
one
=
type_convert
<
bhalf_t
>
(
1
);
y
=
ck
::
math
::
log
(
one
+
ck
::
math
::
exp
(
x
*
casted_alpha
))
/
casted_alpha
;
}
const
float
alpha_
;
};
struct
Power
struct
Power
final
:
public
UnaryOpBase
{
Power
(
float
alpha
=
0.
f
,
float
beta
=
1.
f
,
float
gamma
=
2.
f
)
:
alpha_
(
alpha
),
beta_
(
beta
),
gamma_
(
gamma
){};
__host__
__device__
constexpr
Power
(
const
Power
&
)
=
default
;
__host__
__device__
constexpr
Power
(
Power
&&
)
=
default
;
__host__
__device__
~
Power
()
=
default
;
template
<
typename
T
>
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
__host__
__device__
Power
(
float
alpha
=
0.
f
,
float
beta
=
1.
f
,
float
gamma
=
2.
f
)
:
alpha_
(
alpha
),
beta_
(
beta
),
gamma_
(
gamma
)
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
"Data type is not supported by this operation!"
);
T
casted_alpha
=
type_convert
<
T
>
(
alpha_
);
T
casted_beta
=
type_convert
<
T
>
(
beta_
);
T
casted_gamma
=
type_convert
<
T
>
(
gamma_
);
T
shifted_scaled_x
=
casted_alpha
+
casted_beta
*
x
;
y
=
ck
::
math
::
pow
(
shifted_scaled_x
,
casted_gamma
);
}
__host__
__device__
float
get_alpha
()
const
{
return
alpha_
;
}
__host__
__device__
float
get_beta
()
const
{
return
beta_
;
}
__host__
__device__
float
get_gamma
()
const
{
return
gamma_
;
}
const
float
alpha_
;
const
float
beta_
;
const
float
gamma_
;
__host__
__device__
inline
void
operator
()(
float
&
y
,
const
float
&
x
)
const
final
{
float
casted_alpha
=
type_convert
<
float
>
(
alpha_
);
float
casted_beta
=
type_convert
<
float
>
(
beta_
);
float
casted_gamma
=
type_convert
<
float
>
(
gamma_
);
float
shifted_scaled_x
=
casted_alpha
+
casted_beta
*
x
;
y
=
ck
::
math
::
pow
(
shifted_scaled_x
,
casted_gamma
);
}
__host__
__device__
inline
void
operator
()(
double
&
y
,
const
double
&
x
)
const
final
{
double
casted_alpha
=
type_convert
<
double
>
(
alpha_
);
double
casted_beta
=
type_convert
<
double
>
(
beta_
);
double
casted_gamma
=
type_convert
<
double
>
(
gamma_
);
double
shifted_scaled_x
=
casted_alpha
+
casted_beta
*
x
;
y
=
ck
::
math
::
pow
(
shifted_scaled_x
,
casted_gamma
);
}
__host__
__device__
inline
void
operator
()(
int32_t
&
y
,
const
int32_t
&
x
)
const
final
{
int32_t
casted_alpha
=
type_convert
<
int32_t
>
(
alpha_
);
int32_t
casted_beta
=
type_convert
<
int32_t
>
(
beta_
);
int32_t
casted_gamma
=
type_convert
<
int32_t
>
(
gamma_
);
int32_t
shifted_scaled_x
=
casted_alpha
+
casted_beta
*
x
;
y
=
ck
::
math
::
pow
(
shifted_scaled_x
,
casted_gamma
);
}
__host__
__device__
inline
void
operator
()(
int8_t
&
y
,
const
int8_t
&
x
)
const
final
{
int8_t
casted_alpha
=
type_convert
<
int8_t
>
(
alpha_
);
int8_t
casted_beta
=
type_convert
<
int8_t
>
(
beta_
);
int8_t
casted_gamma
=
type_convert
<
int8_t
>
(
gamma_
);
int8_t
shifted_scaled_x
=
casted_alpha
+
casted_beta
*
x
;
y
=
ck
::
math
::
pow
(
shifted_scaled_x
,
casted_gamma
);
}
__host__
__device__
inline
void
operator
()(
half_t
&
y
,
const
half_t
&
x
)
const
final
{
half_t
casted_alpha
=
type_convert
<
half_t
>
(
alpha_
);
half_t
casted_beta
=
type_convert
<
half_t
>
(
beta_
);
half_t
casted_gamma
=
type_convert
<
half_t
>
(
gamma_
);
half_t
shifted_scaled_x
=
casted_alpha
+
casted_beta
*
x
;
y
=
ck
::
math
::
pow
(
shifted_scaled_x
,
casted_gamma
);
}
__host__
__device__
inline
void
operator
()(
bhalf_t
&
y
,
const
bhalf_t
&
x
)
const
final
{
bhalf_t
casted_alpha
=
type_convert
<
bhalf_t
>
(
alpha_
);
bhalf_t
casted_beta
=
type_convert
<
bhalf_t
>
(
beta_
);
bhalf_t
casted_gamma
=
type_convert
<
bhalf_t
>
(
gamma_
);
bhalf_t
shifted_scaled_x
=
casted_alpha
+
casted_beta
*
x
;
y
=
ck
::
math
::
pow
(
shifted_scaled_x
,
casted_gamma
);
}
};
struct
ClippedRelu
struct
ClippedRelu
final
:
public
UnaryOpBase
{
ClippedRelu
(
float
alpha
=
0.
f
,
float
beta
=
1.
f
)
:
alpha_
(
alpha
),
beta_
(
beta
){};
__host__
__device__
constexpr
ClippedRelu
(
const
ClippedRelu
&
)
=
default
;
__host__
__device__
constexpr
ClippedRelu
(
ClippedRelu
&&
)
=
default
;
__host__
__device__
~
ClippedRelu
()
=
default
;
template
<
typename
T
>
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
__host__
__device__
ClippedRelu
(
float
alpha
=
0.
f
,
float
beta
=
1.
f
)
:
alpha_
(
alpha
),
beta_
(
beta
)
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
"Data type is not supported by this operation!"
);
T
casted_alpha
=
type_convert
<
T
>
(
alpha_
);
T
casted_beta
=
type_convert
<
T
>
(
beta_
);
y
=
ck
::
math
::
min
(
casted_beta
,
ck
::
math
::
max
(
casted_alpha
,
x
));
}
__host__
__device__
float
get_alpha
()
const
{
return
alpha_
;
}
__host__
__device__
float
get_beta
()
const
{
return
beta_
;
}
const
float
alpha_
;
const
float
beta_
;
__host__
__device__
inline
void
operator
()(
float
&
y
,
const
float
&
x
)
const
final
{
float
casted_alpha
=
type_convert
<
float
>
(
alpha_
);
float
casted_beta
=
type_convert
<
float
>
(
beta_
);
y
=
ck
::
math
::
min
(
casted_beta
,
ck
::
math
::
max
(
casted_alpha
,
x
));
}
__host__
__device__
inline
void
operator
()(
double
&
y
,
const
double
&
x
)
const
final
{
double
casted_alpha
=
type_convert
<
double
>
(
alpha_
);
double
casted_beta
=
type_convert
<
double
>
(
beta_
);
y
=
ck
::
math
::
min
(
casted_beta
,
ck
::
math
::
max
(
casted_alpha
,
x
));
}
__host__
__device__
inline
void
operator
()(
int32_t
&
y
,
const
int32_t
&
x
)
const
final
{
int32_t
casted_alpha
=
type_convert
<
int32_t
>
(
alpha_
);
int32_t
casted_beta
=
type_convert
<
int32_t
>
(
beta_
);
y
=
ck
::
math
::
min
(
casted_beta
,
ck
::
math
::
max
(
casted_alpha
,
x
));
}
__host__
__device__
inline
void
operator
()(
int8_t
&
y
,
const
int8_t
&
x
)
const
final
{
int8_t
casted_alpha
=
type_convert
<
int8_t
>
(
alpha_
);
int8_t
casted_beta
=
type_convert
<
int8_t
>
(
beta_
);
y
=
ck
::
math
::
min
(
casted_beta
,
ck
::
math
::
max
(
casted_alpha
,
x
));
}
__host__
__device__
inline
void
operator
()(
half_t
&
y
,
const
half_t
&
x
)
const
final
{
half_t
casted_alpha
=
type_convert
<
half_t
>
(
alpha_
);
half_t
casted_beta
=
type_convert
<
half_t
>
(
beta_
);
y
=
ck
::
math
::
min
(
casted_beta
,
ck
::
math
::
max
(
casted_alpha
,
x
));
}
__host__
__device__
inline
void
operator
()(
bhalf_t
&
y
,
const
bhalf_t
&
x
)
const
final
{
bhalf_t
casted_alpha
=
type_convert
<
bhalf_t
>
(
alpha_
);
bhalf_t
casted_beta
=
type_convert
<
bhalf_t
>
(
beta_
);
y
=
ck
::
math
::
min
(
casted_beta
,
ck
::
math
::
max
(
casted_alpha
,
x
));
}
};
struct
LeakyRelu
struct
LeakyRelu
final
:
public
UnaryOpBase
{
LeakyRelu
(
float
alpha
=
0.01
f
)
:
alpha_
(
alpha
){};
__host__
__device__
constexpr
LeakyRelu
(
const
LeakyRelu
&
)
=
default
;
__host__
__device__
constexpr
LeakyRelu
(
LeakyRelu
&&
)
=
default
;
__host__
__device__
~
LeakyRelu
()
=
default
;
template
<
typename
T
>
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
__host__
__device__
LeakyRelu
(
float
alpha
=
0.
f
)
:
alpha_
(
alpha
)
{}
__host__
__device__
float
get_alpha
()
const
{
return
alpha_
;
}
const
float
alpha_
;
__host__
__device__
inline
void
operator
()(
float
&
y
,
const
float
&
x
)
const
final
{
float
casted_alpha
=
type_convert
<
float
>
(
alpha_
);
y
=
x
>=
0
?
x
:
x
*
casted_alpha
;
}
__host__
__device__
inline
void
operator
()(
double
&
y
,
const
double
&
x
)
const
final
{
double
casted_alpha
=
type_convert
<
double
>
(
alpha_
);
y
=
x
>=
0
?
x
:
x
*
casted_alpha
;
}
__host__
__device__
inline
void
operator
()(
int32_t
&
y
,
const
int32_t
&
x
)
const
final
{
int32_t
casted_alpha
=
type_convert
<
int32_t
>
(
alpha_
);
y
=
x
>=
0
?
x
:
x
*
casted_alpha
;
}
__host__
__device__
inline
void
operator
()(
int8_t
&
y
,
const
int8_t
&
x
)
const
final
{
int8_t
casted_alpha
=
type_convert
<
int8_t
>
(
alpha_
);
y
=
x
>=
0
?
x
:
x
*
casted_alpha
;
}
__host__
__device__
inline
void
operator
()(
half_t
&
y
,
const
half_t
&
x
)
const
final
{
half_t
casted_alpha
=
type_convert
<
half_t
>
(
alpha_
);
y
=
x
>=
0
?
x
:
x
*
casted_alpha
;
}
__host__
__device__
inline
void
operator
()([[
maybe_unused
]]
bhalf_t
&
y
,
[[
maybe_unused
]]
const
bhalf_t
&
x
)
const
final
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
"Data type is not supported by this operation!"
);
T
casted_alpha
=
type_convert
<
T
>
(
alpha_
);
y
=
x
>=
0
?
x
:
x
*
casted_alpha
;
}
const
float
alpha_
;
};
struct
Elu
struct
Elu
final
:
public
UnaryOpBase
{
Elu
(
float
alpha
=
1.
f
)
:
alpha_
(
alpha
){};
__host__
__device__
constexpr
Elu
(
const
Elu
&
)
=
default
;
__host__
__device__
constexpr
Elu
(
Elu
&&
)
=
default
;
__host__
__device__
~
Elu
()
=
default
;
template
<
typename
T
>
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
__host__
__device__
Elu
(
float
alpha
=
1.
f
)
:
alpha_
(
alpha
)
{}
__host__
__device__
float
get_alpha
()
const
{
return
alpha_
;
}
const
float
alpha_
;
__host__
__device__
inline
void
operator
()(
float
&
y
,
const
float
&
x
)
const
final
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
"Data type is not supported by this operation!"
);
T
casted_alpha
=
type_convert
<
T
>
(
alpha_
);
y
=
x
>
0
?
x
:
casted_alpha
*
ck
::
math
::
expm1
(
x
);
float
casted_alpha
=
type_convert
<
float
>
(
alpha_
);
y
=
x
>
0
?
x
:
casted_alpha
*
ck
::
math
::
expm1
(
x
);
}
__host__
__device__
inline
void
operator
()(
double
&
y
,
const
double
&
x
)
const
final
{
double
casted_alpha
=
type_convert
<
double
>
(
alpha_
);
y
=
x
>
0
?
x
:
casted_alpha
*
ck
::
math
::
expm1
(
x
);
}
__host__
__device__
inline
void
operator
()(
int32_t
&
y
,
const
int32_t
&
x
)
const
final
{
int32_t
casted_alpha
=
type_convert
<
int32_t
>
(
alpha_
);
y
=
x
>
0
?
x
:
casted_alpha
*
ck
::
math
::
expm1
(
x
);
}
__host__
__device__
inline
void
operator
()(
int8_t
&
y
,
const
int8_t
&
x
)
const
final
{
int8_t
casted_alpha
=
type_convert
<
int8_t
>
(
alpha_
);
y
=
x
>
0
?
x
:
casted_alpha
*
ck
::
math
::
expm1
(
x
);
}
__host__
__device__
inline
void
operator
()(
half_t
&
y
,
const
half_t
&
x
)
const
final
{
half_t
casted_alpha
=
type_convert
<
half_t
>
(
alpha_
);
y
=
x
>
0
?
x
:
casted_alpha
*
ck
::
math
::
expm1
(
x
);
}
__host__
__device__
inline
void
operator
()(
bhalf_t
&
y
,
const
bhalf_t
&
x
)
const
final
{
bhalf_t
casted_alpha
=
type_convert
<
bhalf_t
>
(
alpha_
);
y
=
x
>
0
?
x
:
casted_alpha
*
ck
::
math
::
expm1
(
x
);
}
const
float
alpha_
;
};
-struct Logistic
struct Logistic final : public UnaryOpBase
{
-    Logistic(float alpha = 1.f) : alpha_(alpha){};

    __host__ __device__ constexpr Logistic(const Logistic&) = default;
    __host__ __device__ constexpr Logistic(Logistic&&)      = default;
    __host__ __device__ ~Logistic()                         = default;

-    template <typename T>
-    __host__ __device__ void operator()(T& y, const T& x) const
    __host__ __device__ Logistic(float alpha = 1.0f) : alpha_(alpha) {}

    __host__ __device__ float get_alpha() const { return alpha_; }

    const float alpha_;

    __host__ __device__ inline void operator()(float& y, const float& x) const final
    {
-        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
-                          is_same<T, int8_t>::value,
-                      "Data type is not supported by this operation!");
-        T casted_alpha  = type_convert<T>(alpha_);
-        constexpr T one = type_convert<T>(1);
-        y               = casted_alpha / (one + ck::math::exp(-x) * casted_alpha);
        float casted_alpha  = type_convert<float>(alpha_);
        constexpr float one = type_convert<float>(1);
        y                   = casted_alpha / (one + ck::math::exp(-x) * casted_alpha);
    }

    __host__ __device__ inline void operator()(double& y, const double& x) const final
    {
        double casted_alpha  = type_convert<double>(alpha_);
        constexpr double one = type_convert<double>(1);
        y                    = casted_alpha / (one + ck::math::exp(-x) * casted_alpha);
    }

    __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final
    {
        int32_t casted_alpha  = type_convert<int32_t>(alpha_);
        constexpr int32_t one = type_convert<int32_t>(1);
        y                     = casted_alpha / (one + ck::math::exp(-x) * casted_alpha);
    }

    __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final
    {
        int8_t casted_alpha  = type_convert<int8_t>(alpha_);
        constexpr int8_t one = type_convert<int8_t>(1);
        y                    = casted_alpha / (one + ck::math::exp(-x) * casted_alpha);
    }

    __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final
    {
        half_t casted_alpha  = type_convert<half_t>(alpha_);
        constexpr half_t one = type_convert<half_t>(1);
        y                    = casted_alpha / (one + ck::math::exp(-x) * casted_alpha);
    }

    __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final
    {
        bhalf_t casted_alpha  = type_convert<bhalf_t>(alpha_);
        constexpr bhalf_t one = type_convert<bhalf_t>(1);
        y                     = casted_alpha / (one + ck::math::exp(-x) * casted_alpha);
    }

    const float alpha_;
};
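The two functors above compute elu(x) = x for x > 0 and alpha * (e^x - 1) otherwise, and logistic(x) = alpha / (1 + e^(-x) * alpha), which reduces to the standard sigmoid when alpha is 1. A minimal host-only reference of the same math on plain float, independent of the CK headers (names here are illustrative, not part of CK):

// Reference-only sketch of the math implemented by Elu and Logistic above.
// Standalone C++; does not depend on CK headers.
#include <cassert>
#include <cmath>
#include <cstdio>

float elu_ref(float x, float alpha) { return x > 0.f ? x : alpha * std::expm1(x); }

float logistic_ref(float x, float alpha) { return alpha / (1.f + std::exp(-x) * alpha); }

int main()
{
    // ELU is the identity for positive inputs and a scaled, shifted exponential otherwise.
    assert(elu_ref(2.0f, 1.0f) == 2.0f);
    assert(std::fabs(elu_ref(-1.0f, 1.0f) - (std::exp(-1.0f) - 1.0f)) < 1e-6f);

    // With alpha == 1 the Logistic functor reduces to the standard sigmoid.
    assert(std::fabs(logistic_ref(0.0f, 1.0f) - 0.5f) < 1e-6f);

    std::printf("elu(-1)     = %f\n", elu_ref(-1.0f, 1.0f));
    std::printf("logistic(0) = %f\n", logistic_ref(0.0f, 1.0f));
    return 0;
}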
struct ConvInvscale
...
...
@@ -1068,7 +1470,7 @@ struct ConvScaleRelu
    __host__ __device__ void operator()<f8_t, float>(f8_t& e, const float& c) const
    {
        float x;
-        Relu{}.template operator()<float>(x, c * scale_in_ * scale_wei_);
        Relu{}(x, c * scale_in_ * scale_wei_);
        e = type_convert<f8_t>(x * scale_out_);
    };
...
...
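The ConvScaleRelu hunk drops the explicit ".template operator()<float>" spelling: once Relu exposes a concrete float overload (as it does after this commit), a plain function-call expression resolves by ordinary overload resolution. A small standalone sketch contrasting the two call forms with stand-in functors (MiniReluTemplate and MiniReluOverload are hypothetical, not CK types):

// Sketch of why the ConvScaleRelu call site can drop ".template operator()<float>".
// MiniRelu* are stand-ins for CK's Relu, not the real implementation.
#include <algorithm>
#include <cstdio>

struct MiniReluTemplate
{
    template <typename T>
    void operator()(T& y, const T& x) const { y = x > T{0} ? x : T{0}; }
};

struct MiniReluOverload
{
    void operator()(float& y, const float& x) const { y = std::max(x, 0.0f); }
};

int main()
{
    float y = 0.f;

    // Template form: the "template operator()<float>" spelling is how the old call site
    // invoked the templated Relu.
    MiniReluTemplate{}.template operator()<float>(y, -3.0f);
    std::printf("templated relu(-3)  = %f\n", y);

    // With a concrete float overload (what the new Relu provides), a plain call suffices.
    MiniReluOverload{}(y, -3.0f);
    std::printf("overloaded relu(-3) = %f\n", y);
    return 0;
}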
@@ -1147,6 +1549,255 @@ struct FastNumericArrayConverter<uint8_t, ck::half_t, N>
    __device__ OutputArray operator()(InputArray const& Input) { return convert(Input); }
};

struct DynamicUnaryOp
{
    DynamicUnaryOp& operator=(const DynamicUnaryOp& other)
    {
        if(this != &other)
        {
            unary_op_ptr_  = other.unary_op_ptr_;
            unary_op_type_ = other.unary_op_type_;
        }
        return *this;
    }

    __host__ __device__ DynamicUnaryOp() = delete;

    __host__ __device__ DynamicUnaryOp(const Swish& swish)
    {
        unary_op_type_ = UnaryOpType::Swish;
        beta           = swish.get_beta();
    }
    __host__ __device__ DynamicUnaryOp(const Swish&& swish)
    {
        unary_op_type_ = UnaryOpType::Swish;
        beta           = swish.get_beta();
    }

    __host__ __device__ DynamicUnaryOp(const Sigmoid&) { unary_op_type_ = UnaryOpType::Sigmoid; }
    __host__ __device__ DynamicUnaryOp(const Sigmoid&&) { unary_op_type_ = UnaryOpType::Sigmoid; }

    __host__ __device__ DynamicUnaryOp(const PassThrough&) { unary_op_type_ = UnaryOpType::PassThrough; }
    __host__ __device__ DynamicUnaryOp(const PassThrough&&) { unary_op_type_ = UnaryOpType::PassThrough; }

    __host__ __device__ DynamicUnaryOp(const Logistic& logistic)
    {
        unary_op_type_ = UnaryOpType::Logistic;
        alpha          = logistic.get_alpha();
    }
    __host__ __device__ DynamicUnaryOp(const Logistic&& logistic)
    {
        unary_op_type_ = UnaryOpType::Logistic;
        alpha          = logistic.get_alpha();
    }

    __host__ __device__ DynamicUnaryOp(const TanH&) { unary_op_type_ = UnaryOpType::TanH; }
    __host__ __device__ DynamicUnaryOp(const TanH&&) { unary_op_type_ = UnaryOpType::TanH; }

    __host__ __device__ DynamicUnaryOp(const Relu&) { unary_op_type_ = UnaryOpType::Relu; }
    __host__ __device__ DynamicUnaryOp(const Relu&&) { unary_op_type_ = UnaryOpType::Relu; }

    __host__ __device__ DynamicUnaryOp(const SoftRelu& softrelu)
    {
        unary_op_type_ = UnaryOpType::SoftRelu;
        alpha          = softrelu.get_alpha();
    }
    __host__ __device__ DynamicUnaryOp(const SoftRelu&& softrelu)
    {
        unary_op_type_ = UnaryOpType::SoftRelu;
        alpha          = softrelu.get_alpha();
    }

    __host__ __device__ DynamicUnaryOp(const UnaryAbs&) { unary_op_type_ = UnaryOpType::UnaryAbs; }
    __host__ __device__ DynamicUnaryOp(const UnaryAbs&&) { unary_op_type_ = UnaryOpType::UnaryAbs; }

    __host__ __device__ DynamicUnaryOp(const Power& pow)
    {
        unary_op_type_ = UnaryOpType::Power;
        alpha          = pow.get_alpha();
        beta           = pow.get_beta();
        gamma          = pow.get_gamma();
    }
    __host__ __device__ DynamicUnaryOp(const Power&& pow)
    {
        unary_op_type_ = UnaryOpType::Power;
        alpha          = pow.get_alpha();
        beta           = pow.get_beta();
        gamma          = pow.get_gamma();
    }

    __host__ __device__ DynamicUnaryOp(const ClippedRelu& clippedrelu)
    {
        unary_op_type_ = UnaryOpType::ClippedRelu;
        alpha          = clippedrelu.get_alpha();
        beta           = clippedrelu.get_beta();
    }
    __host__ __device__ DynamicUnaryOp(const ClippedRelu&& clippedrelu)
    {
        unary_op_type_ = UnaryOpType::ClippedRelu;
        alpha          = clippedrelu.get_alpha();
        beta           = clippedrelu.get_beta();
    }

    __host__ __device__ DynamicUnaryOp(const LeakyRelu& leakyrelu)
    {
        unary_op_type_ = UnaryOpType::LeakyRelu;
        alpha          = leakyrelu.get_alpha();
    }
    __host__ __device__ DynamicUnaryOp(const LeakyRelu&& leakyrelu)
    {
        unary_op_type_ = UnaryOpType::LeakyRelu;
        alpha          = leakyrelu.get_alpha();
    }

    __host__ __device__ DynamicUnaryOp(const Elu& elu)
    {
        unary_op_type_ = UnaryOpType::Elu;
        alpha          = elu.get_alpha();
    }
    __host__ __device__ DynamicUnaryOp(const Elu&& elu)
    {
        unary_op_type_ = UnaryOpType::Elu;
        alpha          = elu.get_alpha();
    }

    __host__ __device__ DynamicUnaryOp(const DynamicUnaryOp& dynamic_op)
        : unary_op_type_(dynamic_op.unary_op_type_),
          unary_op_ptr_(dynamic_op.unary_op_ptr_),
          alpha(dynamic_op.alpha),
          beta(dynamic_op.beta),
          gamma(dynamic_op.gamma)
    {
    }

    __host__ __device__ ~DynamicUnaryOp()
    {
        switch(unary_op_type_)
        {
        case(UnaryOpType::Swish): delete static_cast<Swish*>(unary_op_ptr_); break;
        case(UnaryOpType::Sigmoid): delete static_cast<Sigmoid*>(unary_op_ptr_); break;
        case(UnaryOpType::PassThrough): delete static_cast<PassThrough*>(unary_op_ptr_); break;
        case(UnaryOpType::Logistic): delete static_cast<Logistic*>(unary_op_ptr_); break;
        case(UnaryOpType::TanH): delete static_cast<TanH*>(unary_op_ptr_); break;
        case(UnaryOpType::Relu): delete static_cast<Relu*>(unary_op_ptr_); break;
        case(UnaryOpType::SoftRelu): delete static_cast<SoftRelu*>(unary_op_ptr_); break;
        case(UnaryOpType::UnaryAbs): delete static_cast<UnaryAbs*>(unary_op_ptr_); break;
        case(UnaryOpType::Power): delete static_cast<Power*>(unary_op_ptr_); break;
        case(UnaryOpType::ClippedRelu): delete static_cast<ClippedRelu*>(unary_op_ptr_); break;
        case(UnaryOpType::LeakyRelu): delete static_cast<LeakyRelu*>(unary_op_ptr_); break;
        case(UnaryOpType::Elu): delete static_cast<Elu*>(unary_op_ptr_); break;
        default: break;
        }
    }

    __device__ void InitUnaryOpPtrOnDevice()
    {
        switch(unary_op_type_)
        {
        case(UnaryOpType::Swish): unary_op_ptr_ = new Swish(beta); break;
        case(UnaryOpType::Sigmoid): unary_op_ptr_ = new Sigmoid; break;
        case(UnaryOpType::PassThrough): unary_op_ptr_ = new PassThrough; break;
        case(UnaryOpType::Logistic): unary_op_ptr_ = new Logistic(alpha); break;
        case(UnaryOpType::TanH): unary_op_ptr_ = new TanH; break;
        case(UnaryOpType::Relu): unary_op_ptr_ = new Relu; break;
        case(UnaryOpType::SoftRelu): unary_op_ptr_ = new SoftRelu(alpha); break;
        case(UnaryOpType::UnaryAbs): unary_op_ptr_ = new UnaryAbs; break;
        case(UnaryOpType::Power): unary_op_ptr_ = new Power(alpha, beta, gamma); break;
        case(UnaryOpType::ClippedRelu): unary_op_ptr_ = new ClippedRelu(alpha, beta); break;
        case(UnaryOpType::LeakyRelu): unary_op_ptr_ = new LeakyRelu(alpha); break;
        case(UnaryOpType::Elu): unary_op_ptr_ = new Elu(alpha); break;
        default: unary_op_ptr_ = nullptr; break;
        }
    }

    template <typename Y, typename X>
    __device__ void operator()(Y& y, const X& x) const
    {
        isSupported<X, Y>();
        unary_op_ptr_->operator()(y, x);
    }

    template <typename Y, typename X>
    __host__ void operator()(Y& y, const X& x) const
    {
        isSupported<X, Y>();
        switch(unary_op_type_)
        {
        case(UnaryOpType::Swish): Swish{}.operator()(y, x); break;
        case(UnaryOpType::Sigmoid): Sigmoid{}.operator()(y, x); break;
        case(UnaryOpType::PassThrough): PassThrough{}.operator()(y, x); break;
        case(UnaryOpType::Logistic): Logistic{}.operator()(y, x); break;
        case(UnaryOpType::TanH): TanH{}.operator()(y, x); break;
        case(UnaryOpType::Relu): Relu{}.operator()(y, x); break;
        case(UnaryOpType::SoftRelu): SoftRelu{}.operator()(y, x); break;
        case(UnaryOpType::UnaryAbs): UnaryAbs{}.operator()(y, x); break;
        case(UnaryOpType::Power): Power{}.operator()(y, x); break;
        case(UnaryOpType::ClippedRelu): ClippedRelu{}.operator()(y, x); break;
        case(UnaryOpType::LeakyRelu): LeakyRelu{}.operator()(y, x); break;
        case(UnaryOpType::Elu): Elu{}.operator()(y, x); break;
        default: break;
        }
    }

    template <typename X, typename Y>
    __device__ __host__ constexpr void isSupported() const
    {
        static_assert(std::is_same<X, Y>::value, "X and Y must be of the same type");
        static_assert(is_same<X, float>::value || is_same<X, double>::value ||
                          is_same<X, bhalf_t>::value || is_same<X, half_t>::value ||
                          is_same<X, int32_t>::value || is_same<X, int8_t>::value,
                      "Data type is not supported by this operation!");
    }

    private:
    enum class UnaryOpType
    {
        Swish,
        Sigmoid,
        PassThrough,
        Logistic,
        TanH,
        Relu,
        SoftRelu,
        UnaryAbs,
        Power,
        ClippedRelu,
        LeakyRelu,
        Elu
    };

    public:
    UnaryOpType unary_op_type_;
    UnaryOpBase* unary_op_ptr_ = nullptr;

    float alpha;
    float beta;
    float gamma;
};
#pragma clang diagnostic pop

} // namespace element_wise
} // namespace tensor_operation
} // namespace ck
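DynamicUnaryOp type-erases the concrete activation behind an enum tag plus a base-class pointer: constructors only record the op kind and its alpha/beta/gamma parameters, the host-side operator() dispatches through the switch above, and a kernel is expected to call InitUnaryOpPtrOnDevice() once so that the device-side operator() can forward through unary_op_ptr_. A rough host-side usage sketch, assuming the CK include path below and a HIP-capable toolchain; treat it as an illustration rather than a tested snippet:

// Host-side sketch: wrap a concrete element-wise op in DynamicUnaryOp and apply it.
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"

#include <cstdio>

int main()
{
    using ck::tensor_operation::element_wise::DynamicUnaryOp;
    using ck::tensor_operation::element_wise::Relu;

    // The constructor records the op kind (and any alpha/beta/gamma parameters).
    DynamicUnaryOp op{Relu{}};

    float x = -2.5f;
    float y = 0.f;
    op(y, x); // host overload: dispatches on unary_op_type_ via the switch
    std::printf("relu(%f) = %f\n", x, y);

    // On the device side a kernel would first call op.InitUnaryOpPtrOnDevice()
    // and then invoke op(y, x), which forwards through unary_op_ptr_.
    return 0;
}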
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp
...
...
@@ -60,12 +60,12 @@ __global__ void
        __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
    const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);

-    const long_index_t a_batch_offset =
-        amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
-    const long_index_t b_batch_offset =
-        amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
-    const long_index_t e_batch_offset =
-        amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
    const long_index_t a_batch_offset = amd_wave_read_first_lane(
        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
    const long_index_t b_batch_offset = amd_wave_read_first_lane(
        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
    const long_index_t e_batch_offset = amd_wave_read_first_lane(
        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));

    const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
...
...
@@ -155,12 +155,12 @@ __global__ void
        __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
    const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);

-    const long_index_t a_batch_offset =
-        amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
-    const long_index_t b_batch_offset =
-        amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
-    const long_index_t e_batch_offset =
-        amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
    const long_index_t a_batch_offset = amd_wave_read_first_lane(
        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
    const long_index_t b_batch_offset = amd_wave_read_first_lane(
        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
    const long_index_t e_batch_offset = amd_wave_read_first_lane(
        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));

    const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
...
...
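Both hunks make the same change: the per-batch pointer offsets are widened to long_index_t before being broadcast with amd_wave_read_first_lane, so a large batch offset is carried as a 64-bit value instead of being narrowed to 32 bits on the way through. A tiny standalone illustration of the truncation this avoids when a batch stride times a batch index no longer fits in 32 bits (the sizes are made up for the example):

// Why widening the batch offset matters: a large batch stride times a batch index
// can exceed INT32_MAX, so computing or broadcasting it as a 32-bit value truncates.
#include <cstdint>
#include <cstdio>

int main()
{
    const std::int64_t batch_stride = 3LL * 1024 * 1024 * 1024; // ~3G elements per batch (made up)
    const std::int32_t batch_idx    = 2;

    const std::int32_t narrow = static_cast<std::int32_t>(batch_stride * batch_idx); // truncated
    const std::int64_t wide   = batch_stride * static_cast<std::int64_t>(batch_idx); // correct

    std::printf("32-bit offset: %d\n", narrow);
    std::printf("64-bit offset: %lld\n", static_cast<long long>(wide));
    return 0;
}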
include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp
...
...
@@ -121,10 +121,10 @@ struct GridwiseTensorRearrange
        __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);

    // Global Memory
-    const index_t a_batch_offset =
-        __builtin_amdgcn_readfirstlane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
-    const index_t c_batch_offset =
-        __builtin_amdgcn_readfirstlane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx));
    const index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
    const index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));

    const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
        p_in_global + a_batch_offset, in_grid_desc.GetElementSpaceSize());
...
...
include/ck/utility/amd_xdlops.hpp
...
...
@@ -327,12 +327,12 @@ struct intrin_mfma_i32_16x16x32i8<16, 16>
    __device__ static void Run(const int8x8_t& reg_a, const int8x8_t& reg_b, FloatC& reg_c)
    {
        reg_c.template AsType<int32x4_t>()(Number<0>{}) =
-            __builtin_amdgcn_mfma_i32_16x16x32i8(bit_cast<int64_t>(reg_a),
-                                                 bit_cast<int64_t>(reg_b),
-                                                 reg_c.template AsType<int32x4_t>()[Number<0>{}],
-                                                 0,
-                                                 0,
-                                                 0);
            __builtin_amdgcn_mfma_i32_16x16x32_i8(bit_cast<int64_t>(reg_a),
                                                  bit_cast<int64_t>(reg_b),
                                                  reg_c.template AsType<int32x4_t>()[Number<0>{}],
                                                  0,
                                                  0,
                                                  0);
    }
};
...
...
include/ck/utility/data_type.hpp
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once
...
...
@@ -13,8 +13,24 @@ using int4_t = _BitInt(4);
using f8_t  = _BitInt(8);
using bf8_t = unsigned _BitInt(8);

inline constexpr auto next_pow2(uint32_t x)
{
    // Precondition: x > 1.
    return x > 1u ? (1u << (32u - __builtin_clz(x - 1u))) : x;
}

// native types: double, float, _Float16, ushort, int32_t, int8_t, uint8_t, f8_t, bf8_t, bool
template <typename T>
inline constexpr bool is_native_type()
{
    return is_same<T, double>::value || is_same<T, float>::value || is_same<T, half_t>::value ||
           is_same<T, bhalf_t>::value || is_same<T, int32_t>::value || is_same<T, int8_t>::value ||
           is_same<T, uint8_t>::value || is_same<T, f8_t>::value || is_same<T, bf8_t>::value ||
           is_same<T, bool>::value;
}

// vector_type
-template <typename T, index_t N>
template <typename T, index_t N, typename Enable = void>
struct vector_type;

// Caution: DO NOT REMOVE
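The next_pow2 helper rounds an unsigned value up to the next power of two (0 and 1 pass through unchanged, since the shift is only taken for x > 1), and it is used further down to pick the alignment of the non-native vector_type unions via alignas(next_pow2(N * sizeof(T))). A quick standalone check of its behaviour, re-declared locally so the snippet builds without the CK headers (GCC/Clang, which accept __builtin_clz in constant expressions):

// Standalone check of the next_pow2 rounding used for the union alignment below.
#include <cstdint>
#include <cstdio>

inline constexpr std::uint32_t next_pow2(std::uint32_t x)
{
    return x > 1u ? (1u << (32u - __builtin_clz(x - 1u))) : x;
}

static_assert(next_pow2(0u) == 0u, "0 and 1 pass through unchanged");
static_assert(next_pow2(1u) == 1u, "0 and 1 pass through unchanged");
static_assert(next_pow2(3u) == 4u, "rounds up to the next power of two");
static_assert(next_pow2(8u) == 8u, "powers of two are unchanged");
static_assert(next_pow2(24u) == 32u, "e.g. 24 bytes of storage aligns to 32");

int main()
{
    std::printf("next_pow2(24) = %u\n", next_pow2(24u));
    return 0;
}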
...
...
@@ -171,7 +187,7 @@ struct scalar_type<bool>
};

template <typename T>
-struct vector_type<T, 1>
struct vector_type<T, 1, typename std::enable_if_t<is_native_type<T>()>>
{
    using d1_t = T;
    using type = d1_t;
...
...
@@ -189,7 +205,8 @@ struct vector_type<T, 1>
    template <typename X>
    __host__ __device__ constexpr const auto& AsType() const
    {
-        static_assert(is_same<X, d1_t>::value, "wrong!");
        static_assert(is_same<X, d1_t>::value,
                      "Something went wrong, please check src and dst types.");

        return data_.d1x1_;
    }
...
...
@@ -197,7 +214,8 @@ struct vector_type<T, 1>
    template <typename X>
    __host__ __device__ constexpr auto& AsType()
    {
-        static_assert(is_same<X, d1_t>::value, "wrong!");
        static_assert(is_same<X, d1_t>::value,
                      "Something went wrong, please check src and dst types.");

        return data_.d1x1_;
    }
...
...
@@ -205,7 +223,7 @@ struct vector_type<T, 1>
__device__
int
static
err
=
0
;
template
<
typename
T
>
struct
vector_type
<
T
,
2
>
struct
vector_type
<
T
,
2
,
typename
std
::
enable_if_t
<
is_native_type
<
T
>
()
>
>
{
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
...
...
@@ -226,7 +244,8 @@ struct vector_type<T, 2>
template
<
typename
X
>
__host__
__device__
constexpr
const
auto
&
AsType
()
const
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
,
"wrong!"
);
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
...
...
@@ -245,7 +264,8 @@ struct vector_type<T, 2>
template
<
typename
X
>
__host__
__device__
constexpr
auto
&
AsType
()
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
,
"wrong!"
);
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
...
...
@@ -263,7 +283,7 @@ struct vector_type<T, 2>
};
template
<
typename
T
>
struct
vector_type
<
T
,
4
>
struct
vector_type
<
T
,
4
,
typename
std
::
enable_if_t
<
is_native_type
<
T
>
()
>
>
{
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
...
...
@@ -287,7 +307,7 @@ struct vector_type<T, 4>
__host__
__device__
constexpr
const
auto
&
AsType
()
const
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
,
"
wrong!
"
);
"
Something went wrong, please check src and dst types.
"
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
...
...
@@ -311,7 +331,7 @@ struct vector_type<T, 4>
__host__
__device__
constexpr
auto
&
AsType
()
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
,
"
wrong!
"
);
"
Something went wrong, please check src and dst types.
"
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
...
...
@@ -333,7 +353,7 @@ struct vector_type<T, 4>
};
template
<
typename
T
>
struct
vector_type
<
T
,
8
>
struct
vector_type
<
T
,
8
,
typename
std
::
enable_if_t
<
is_native_type
<
T
>
()
>
>
{
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
...
...
@@ -360,7 +380,7 @@ struct vector_type<T, 8>
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
,
"
wrong!
"
);
"
Something went wrong, please check src and dst types.
"
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
...
...
@@ -389,7 +409,7 @@ struct vector_type<T, 8>
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
,
"
wrong!
"
);
"
Something went wrong, please check src and dst types.
"
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
...
...
@@ -415,7 +435,7 @@ struct vector_type<T, 8>
};
template
<
typename
T
>
struct
vector_type
<
T
,
16
>
struct
vector_type
<
T
,
16
,
typename
std
::
enable_if_t
<
is_native_type
<
T
>
()
>
>
{
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
...
...
@@ -445,7 +465,7 @@ struct vector_type<T, 16>
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
||
is_same
<
X
,
d16_t
>::
value
,
"
wrong!
"
);
"
Something went wrong, please check src and dst types.
"
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
...
...
@@ -479,7 +499,7 @@ struct vector_type<T, 16>
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
||
is_same
<
X
,
d16_t
>::
value
,
"
wrong!
"
);
"
Something went wrong, please check src and dst types.
"
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
...
...
@@ -509,7 +529,7 @@ struct vector_type<T, 16>
};
template
<
typename
T
>
struct
vector_type
<
T
,
32
>
struct
vector_type
<
T
,
32
,
typename
std
::
enable_if_t
<
is_native_type
<
T
>
()
>
>
{
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
...
...
@@ -541,7 +561,7 @@ struct vector_type<T, 32>
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
||
is_same
<
X
,
d16_t
>::
value
||
is_same
<
X
,
d32_t
>::
value
,
"
wrong!
"
);
"
Something went wrong, please check src and dst types.
"
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
...
...
@@ -579,7 +599,7 @@ struct vector_type<T, 32>
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
||
is_same
<
X
,
d16_t
>::
value
||
is_same
<
X
,
d32_t
>::
value
,
"
wrong!
"
);
"
Something went wrong, please check src and dst types.
"
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
...
...
@@ -613,7 +633,7 @@ struct vector_type<T, 32>
};
template
<
typename
T
>
struct
vector_type
<
T
,
64
>
struct
vector_type
<
T
,
64
,
typename
std
::
enable_if_t
<
is_native_type
<
T
>
()
>
>
{
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
...
...
@@ -648,7 +668,7 @@ struct vector_type<T, 64>
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
||
is_same
<
X
,
d16_t
>::
value
||
is_same
<
X
,
d32_t
>::
value
||
is_same
<
X
,
d64_t
>::
value
,
"
wrong!
"
);
"
Something went wrong, please check src and dst types.
"
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
...
...
@@ -691,7 +711,7 @@ struct vector_type<T, 64>
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
||
is_same
<
X
,
d16_t
>::
value
||
is_same
<
X
,
d32_t
>::
value
||
is_same
<
X
,
d64_t
>::
value
,
"
wrong!
"
);
"
Something went wrong, please check src and dst types.
"
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
...
...
@@ -729,7 +749,7 @@ struct vector_type<T, 64>
};
template
<
typename
T
>
struct
vector_type
<
T
,
128
>
struct
vector_type
<
T
,
128
,
typename
std
::
enable_if_t
<
is_native_type
<
T
>
()
>
>
{
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
...
...
@@ -766,7 +786,7 @@ struct vector_type<T, 128>
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
||
is_same
<
X
,
d16_t
>::
value
||
is_same
<
X
,
d32_t
>::
value
||
is_same
<
X
,
d64_t
>::
value
||
is_same
<
X
,
d128_t
>::
value
,
"
wrong!
"
);
"
Something went wrong, please check src and dst types.
"
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
...
...
@@ -813,7 +833,7 @@ struct vector_type<T, 128>
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
||
is_same
<
X
,
d16_t
>::
value
||
is_same
<
X
,
d32_t
>::
value
||
is_same
<
X
,
d64_t
>::
value
||
is_same
<
X
,
d128_t
>::
value
,
"
wrong!
"
);
"
Something went wrong, please check src and dst types.
"
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
...
...
@@ -855,7 +875,7 @@ struct vector_type<T, 128>
};
template
<
typename
T
>
struct
vector_type
<
T
,
256
>
struct
vector_type
<
T
,
256
,
typename
std
::
enable_if_t
<
is_native_type
<
T
>
()
>
>
{
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
...
...
@@ -894,7 +914,7 @@ struct vector_type<T, 256>
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
||
is_same
<
X
,
d16_t
>::
value
||
is_same
<
X
,
d32_t
>::
value
||
is_same
<
X
,
d64_t
>::
value
||
is_same
<
X
,
d128_t
>::
value
||
is_same
<
X
,
d256_t
>::
value
,
"
wrong!
"
);
"
Something went wrong, please check src and dst types.
"
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
...
...
@@ -945,7 +965,7 @@ struct vector_type<T, 256>
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
||
is_same
<
X
,
d16_t
>::
value
||
is_same
<
X
,
d32_t
>::
value
||
is_same
<
X
,
d64_t
>::
value
||
is_same
<
X
,
d128_t
>::
value
||
is_same
<
X
,
d256_t
>::
value
,
"
wrong!
"
);
"
Something went wrong, please check src and dst types.
"
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
...
...
@@ -990,177 +1010,752 @@ struct vector_type<T, 256>
}
};
using
int64_t
=
long
;
// fp64
using
double2_t
=
typename
vector_type
<
double
,
2
>::
type
;
using
double4_t
=
typename
vector_type
<
double
,
4
>::
type
;
// fp32
using
float2_t
=
typename
vector_type
<
float
,
2
>::
type
;
using
float4_t
=
typename
vector_type
<
float
,
4
>::
type
;
using
float8_t
=
typename
vector_type
<
float
,
8
>::
type
;
using
float16_t
=
typename
vector_type
<
float
,
16
>::
type
;
using
float32_t
=
typename
vector_type
<
float
,
32
>::
type
;
using
float64_t
=
typename
vector_type
<
float
,
64
>::
type
;
// fp16
using
half2_t
=
typename
vector_type
<
half_t
,
2
>::
type
;
using
half4_t
=
typename
vector_type
<
half_t
,
4
>::
type
;
using
half8_t
=
typename
vector_type
<
half_t
,
8
>::
type
;
using
half16_t
=
typename
vector_type
<
half_t
,
16
>::
type
;
using
half32_t
=
typename
vector_type
<
half_t
,
32
>::
type
;
using
half64_t
=
typename
vector_type
<
half_t
,
64
>::
type
;
// bfp16
using
bhalf2_t
=
typename
vector_type
<
bhalf_t
,
2
>::
type
;
using
bhalf4_t
=
typename
vector_type
<
bhalf_t
,
4
>::
type
;
using
bhalf8_t
=
typename
vector_type
<
bhalf_t
,
8
>::
type
;
using
bhalf16_t
=
typename
vector_type
<
bhalf_t
,
16
>::
type
;
using
bhalf32_t
=
typename
vector_type
<
bhalf_t
,
32
>::
type
;
using
bhalf64_t
=
typename
vector_type
<
bhalf_t
,
64
>::
type
;
// i32
using
int32x2_t
=
typename
vector_type
<
int32_t
,
2
>::
type
;
using
int32x4_t
=
typename
vector_type
<
int32_t
,
4
>::
type
;
using
int32x8_t
=
typename
vector_type
<
int32_t
,
8
>::
type
;
using
int32x16_t
=
typename
vector_type
<
int32_t
,
16
>::
type
;
using
int32x32_t
=
typename
vector_type
<
int32_t
,
32
>::
type
;
using
int32x64_t
=
typename
vector_type
<
int32_t
,
64
>::
type
;
// i8
using
int8x2_t
=
typename
vector_type
<
int8_t
,
2
>::
type
;
using
int8x4_t
=
typename
vector_type
<
int8_t
,
4
>::
type
;
using
int8x8_t
=
typename
vector_type
<
int8_t
,
8
>::
type
;
using
int8x16_t
=
typename
vector_type
<
int8_t
,
16
>::
type
;
using
int8x32_t
=
typename
vector_type
<
int8_t
,
32
>::
type
;
using
int8x64_t
=
typename
vector_type
<
int8_t
,
64
>::
type
;
template
<
typename
T
,
index_t
N
>
struct
non_native_vector_base
{
using
type
=
non_native_vector_base
<
T
,
N
>
;
// f8
using
f8x2_t
=
typename
vector_type
<
f8_t
,
2
>::
type
;
using
f8x4_t
=
typename
vector_type
<
f8_t
,
4
>::
type
;
using
f8x8_t
=
typename
vector_type
<
f8_t
,
8
>::
type
;
using
f8x16_t
=
typename
vector_type
<
f8_t
,
16
>::
type
;
using
f8x32_t
=
typename
vector_type
<
f8_t
,
32
>::
type
;
using
f8x64_t
=
typename
vector_type
<
f8_t
,
64
>::
type
;
__host__
__device__
non_native_vector_base
()
=
default
;
__host__
__device__
non_native_vector_base
(
const
type
&
)
=
default
;
__host__
__device__
non_native_vector_base
(
type
&&
)
=
default
;
__host__
__device__
~
non_native_vector_base
()
=
default
;
// bf8
using
bf8x2_t
=
typename
vector_type
<
bf8_t
,
2
>::
type
;
using
bf8x4_t
=
typename
vector_type
<
bf8_t
,
4
>::
type
;
using
bf8x8_t
=
typename
vector_type
<
bf8_t
,
8
>::
type
;
using
bf8x16_t
=
typename
vector_type
<
bf8_t
,
16
>::
type
;
using
bf8x32_t
=
typename
vector_type
<
bf8_t
,
32
>::
type
;
using
bf8x64_t
=
typename
vector_type
<
bf8_t
,
64
>::
type
;
// u8
// i8
using
uint8x2_t
=
typename
vector_type
<
uint8_t
,
2
>::
type
;
using
uint8x4_t
=
typename
vector_type
<
uint8_t
,
4
>::
type
;
using
uint8x8_t
=
typename
vector_type
<
uint8_t
,
8
>::
type
;
using
uint8x16_t
=
typename
vector_type
<
uint8_t
,
16
>::
type
;
using
uint8x32_t
=
typename
vector_type
<
uint8_t
,
32
>::
type
;
using
uint8x64_t
=
typename
vector_type
<
uint8_t
,
64
>::
type
;
T
d
[
N
];
};
// non-native vector_type implementation
template
<
typename
T
>
struct
NumericLimits
struct
vector_type
<
T
,
1
,
typename
std
::
enable_if_t
<!
is_native_type
<
T
>
()
>>
{
__host__
__device__
static
constexpr
T
Min
()
{
return
std
::
numeric_limits
<
T
>::
min
();
}
using
d1_t
=
T
;
using
type
=
d1_t
;
__host__
__device__
static
constexpr
T
Max
()
{
return
std
::
numeric_limits
<
T
>::
max
();
}
union
alignas
(
next_pow2
(
1
*
sizeof
(
T
)))
{
d1_t
d1_
;
StaticallyIndexedArray
<
d1_t
,
1
>
d1x1_
;
}
data_
;
__host__
__device__
static
constexpr
T
Lowest
()
{
return
std
::
numeric_limits
<
T
>::
lowest
();
}
__host__
__device__
constexpr
vector_type
()
:
data_
{
type
{}}
{
}
__host__
__device__
static
constexpr
T
QuietNaN
()
__host__
__device__
constexpr
vector_type
(
type
v
)
:
data_
{
v
}
{}
template
<
typename
X
>
__host__
__device__
constexpr
const
auto
&
AsType
()
const
{
return
std
::
numeric_limits
<
T
>::
quiet_NaN
();
static_assert
(
is_same
<
X
,
d1_t
>::
value
,
"Something went wrong, please check src and dst types."
);
return
data_
.
d1x1_
;
}
__host__
__device__
static
constexpr
T
Infinity
()
{
return
std
::
numeric_limits
<
T
>::
infinity
();
}
template
<
typename
X
>
__host__
__device__
constexpr
auto
&
AsType
()
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
,
"Something went wrong, please check src and dst types."
);
return
data_
.
d1x1_
;
}
};
template
<
>
struct
NumericLimits
<
half_t
>
template
<
typename
T
>
struct
vector_type
<
T
,
2
,
typename
std
::
enable_if_t
<!
is_native_type
<
T
>
()
>
>
{
static
constexpr
unsigned
short
binary_min
=
0x0400
;
static
constexpr
unsigned
short
binary_max
=
0x7BFF
;
static
constexpr
unsigned
short
binary_lowest
=
0xFBFF
;
static
constexpr
unsigned
short
binary_qnan
=
0x7FFF
;
using
d1_t
=
T
;
using
d2_t
=
non_native_vector_base
<
T
,
2
>
;
__host__
__device__
static
constexpr
half_t
Min
()
{
return
bit_cast
<
half_t
>
(
binary_min
);
}
using
type
=
d2_t
;
__host__
__device__
static
constexpr
half_t
Max
()
{
return
bit_cast
<
half_t
>
(
binary_max
);
}
union
alignas
(
next_pow2
(
2
*
sizeof
(
T
)))
{
d2_t
d2_
;
StaticallyIndexedArray
<
d1_t
,
2
>
d1x2_
;
StaticallyIndexedArray
<
d2_t
,
1
>
d2x1_
;
}
data_
;
__host__
__device__
static
constexpr
half_t
Lowest
()
{
return
bit_cast
<
half_t
>
(
binary_lowest
);
}
__host__
__device__
constexpr
vector_type
()
:
data_
{
type
{}}
{
}
__host__
__device__
static
constexpr
half_t
QuietNaN
()
{
return
bit_cast
<
half_t
>
(
binary_qnan
);
}
};
__host__
__device__
constexpr
vector_type
(
type
v
)
:
data_
{
v
}
{}
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template
<
>
struct
NumericLimits
<
int4_t
>
{
__host__
__device__
static
constexpr
int4_t
Min
()
{
return
int4_t
(
-
8
);
}
template
<
typename
X
>
__host__
__device__
constexpr
const
auto
&
AsType
()
const
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
,
"Something went wrong, please check src and dst types."
);
__host__
__device__
static
constexpr
int4_t
Max
()
{
return
int4_t
(
7
);
}
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x2_
;
}
else
if
constexpr
(
is_same
<
X
,
d2_t
>::
value
)
{
return
data_
.
d2x1_
;
}
else
{
return
err
;
}
}
__host__
__device__
static
constexpr
int4_t
Lowest
()
{
return
int4_t
(
-
8
);
}
template
<
typename
X
>
__host__
__device__
constexpr
auto
&
AsType
()
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x2_
;
}
else
if
constexpr
(
is_same
<
X
,
d2_t
>::
value
)
{
return
data_
.
d2x1_
;
}
else
{
return
err
;
}
}
};
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template
<
>
struct
NumericLimits
<
f8_t
>
template
<
typename
T
>
struct
vector_type
<
T
,
4
,
typename
std
::
enable_if_t
<!
is_native_type
<
T
>
()
>
>
{
// negative zero nan mode with exp bias = 8
static
constexpr
uint8_t
binary_min
=
0x08
;
// 0b00001000
static
constexpr
uint8_t
binary_max
=
0x7F
;
// 0b01111111
static
constexpr
uint8_t
binary_lowest
=
0xFF
;
// 0b11111111
static
constexpr
uint8_t
binary_qnan
=
0x80
;
// 0b10000000
// ieee mode with exp bias = 7
// static constexpr uint8_t binary_min = 0x08; // 0b00001000
// static constexpr uint8_t binary_max = 0x77; // 0b01110111
// static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111
// static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!=0
__host__
__device__
static
constexpr
f8_t
Min
()
{
return
f8_t
(
binary_min
);
}
using
d1_t
=
T
;
using
d2_t
=
non_native_vector_base
<
T
,
2
>
;
using
d4_t
=
non_native_vector_base
<
T
,
4
>
;
__host__
__device__
static
constexpr
f8_t
Max
()
{
return
f8_t
(
binary_max
);
}
using
type
=
d4_t
;
__host__
__device__
static
constexpr
f8_t
Lowest
()
{
return
f8_t
(
binary_lowest
);
}
union
alignas
(
next_pow2
(
4
*
sizeof
(
T
)))
{
d4_t
d4_
;
StaticallyIndexedArray
<
d1_t
,
4
>
d1x4_
;
StaticallyIndexedArray
<
d2_t
,
2
>
d2x2_
;
StaticallyIndexedArray
<
d4_t
,
1
>
d4x1_
;
}
data_
;
__host__
__device__
static
constexpr
f8_t
QuietNaN
()
{
return
f8_t
(
binary_qnan
);
}
};
__host__
__device__
constexpr
vector_type
()
:
data_
{
type
{}}
{}
template
<
>
struct
NumericLimits
<
bf8_t
>
{
// negative zero nan mode with exp bias = 16
static
constexpr
uint8_t
binary_min
=
0x04
;
// 0b00000100
static
constexpr
uint8_t
binary_max
=
0x7F
;
// 0b01111111
static
constexpr
uint8_t
binary_lowest
=
0xFF
;
// 0b11111111
static
constexpr
uint8_t
binary_qnan
=
0x80
;
// 0b10000000
// ieee mode with exp bias = 15
// static constexpr uint8_t binary_min = 0x04; // 0b00000100
// static constexpr uint8_t binary_max = 0x7B; // 0b01111011
// static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011
// static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!=
__host__
__device__
constexpr
vector_type
(
type
v
)
:
data_
{
v
}
{}
__host__
__device__
static
constexpr
bf8_t
Min
()
{
return
bf8_t
(
binary_min
);
}
template
<
typename
X
>
__host__
__device__
constexpr
const
auto
&
AsType
()
const
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
,
"Something went wrong, please check src and dst types."
);
__host__
__device__
static
constexpr
bf8_t
Max
()
{
return
bf8_t
(
binary_max
);
}
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x4_
;
}
else
if
constexpr
(
is_same
<
X
,
d2_t
>::
value
)
{
return
data_
.
d2x2_
;
}
else
if
constexpr
(
is_same
<
X
,
d4_t
>::
value
)
{
return
data_
.
d4x1_
;
}
else
{
return
err
;
}
}
__host__
__device__
static
constexpr
bf8_t
Lowest
()
{
return
bf8_t
(
binary_lowest
);
}
template
<
typename
X
>
__host__
__device__
constexpr
auto
&
AsType
()
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
,
"Something went wrong, please check src and dst types."
);
__host__
__device__
static
constexpr
bf8_t
QuietNaN
()
{
return
bf8_t
(
binary_qnan
);
}
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x4_
;
}
else
if
constexpr
(
is_same
<
X
,
d2_t
>::
value
)
{
return
data_
.
d2x2_
;
}
else
if
constexpr
(
is_same
<
X
,
d4_t
>::
value
)
{
return
data_
.
d4x1_
;
}
else
{
return
err
;
}
}
};
template
<
typename
T
>
struct
NumericUtils
struct
vector_type
<
T
,
8
,
typename
std
::
enable_if_t
<!
is_native_type
<
T
>
()
>>
{
};
using
d1_t
=
T
;
using
d2_t
=
non_native_vector_base
<
T
,
2
>
;
using
d4_t
=
non_native_vector_base
<
T
,
4
>
;
using
d8_t
=
non_native_vector_base
<
T
,
8
>
;
template
<
>
struct
NumericUtils
<
float
>
{
static
constexpr
int
exp
=
8
;
using
type
=
d8_t
;
union
alignas
(
next_pow2
(
8
*
sizeof
(
T
)))
{
d8_t
d8_
;
StaticallyIndexedArray
<
d1_t
,
8
>
d1x8_
;
StaticallyIndexedArray
<
d2_t
,
4
>
d2x4_
;
StaticallyIndexedArray
<
d4_t
,
2
>
d4x2_
;
StaticallyIndexedArray
<
d8_t
,
1
>
d8x1_
;
}
data_
;
__host__
__device__
constexpr
vector_type
()
:
data_
{
type
{}}
{}
__host__
__device__
constexpr
vector_type
(
type
v
)
:
data_
{
v
}
{}
template
<
typename
X
>
__host__
__device__
constexpr
const
auto
&
AsType
()
const
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x8_
;
}
else
if
constexpr
(
is_same
<
X
,
d2_t
>::
value
)
{
return
data_
.
d2x4_
;
}
else
if
constexpr
(
is_same
<
X
,
d4_t
>::
value
)
{
return
data_
.
d4x2_
;
}
else
if
constexpr
(
is_same
<
X
,
d8_t
>::
value
)
{
return
data_
.
d8x1_
;
}
else
{
return
err
;
}
}
template
<
typename
X
>
__host__
__device__
constexpr
auto
&
AsType
()
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x8_
;
}
else
if
constexpr
(
is_same
<
X
,
d2_t
>::
value
)
{
return
data_
.
d2x4_
;
}
else
if
constexpr
(
is_same
<
X
,
d4_t
>::
value
)
{
return
data_
.
d4x2_
;
}
else
if
constexpr
(
is_same
<
X
,
d8_t
>::
value
)
{
return
data_
.
d8x1_
;
}
else
{
return
err
;
}
}
};
template
<
typename
T
>
struct
vector_type
<
T
,
16
,
typename
std
::
enable_if_t
<!
is_native_type
<
T
>
()
>>
{
using
d1_t
=
T
;
using
d2_t
=
non_native_vector_base
<
T
,
2
>
;
using
d4_t
=
non_native_vector_base
<
T
,
4
>
;
using
d8_t
=
non_native_vector_base
<
T
,
8
>
;
using
d16_t
=
non_native_vector_base
<
T
,
16
>
;
using
type
=
d16_t
;
union
alignas
(
next_pow2
(
16
*
sizeof
(
T
)))
{
d16_t
d16_
;
StaticallyIndexedArray
<
d1_t
,
16
>
d1x16_
;
StaticallyIndexedArray
<
d2_t
,
8
>
d2x8_
;
StaticallyIndexedArray
<
d4_t
,
4
>
d4x4_
;
StaticallyIndexedArray
<
d8_t
,
2
>
d8x2_
;
StaticallyIndexedArray
<
d16_t
,
1
>
d16x1_
;
}
data_
;
__host__
__device__
constexpr
vector_type
()
:
data_
{
type
{}}
{}
__host__
__device__
constexpr
vector_type
(
type
v
)
:
data_
{
v
}
{}
template
<
typename
X
>
__host__
__device__
constexpr
const
auto
&
AsType
()
const
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
||
is_same
<
X
,
d16_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x16_
;
}
else
if
constexpr
(
is_same
<
X
,
d2_t
>::
value
)
{
return
data_
.
d2x8_
;
}
else
if
constexpr
(
is_same
<
X
,
d4_t
>::
value
)
{
return
data_
.
d4x4_
;
}
else
if
constexpr
(
is_same
<
X
,
d8_t
>::
value
)
{
return
data_
.
d8x2_
;
}
else
if
constexpr
(
is_same
<
X
,
d16_t
>::
value
)
{
return
data_
.
d16x1_
;
}
else
{
return
err
;
}
}
template
<
typename
X
>
__host__
__device__
constexpr
auto
&
AsType
()
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
||
is_same
<
X
,
d16_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x16_
;
}
else
if
constexpr
(
is_same
<
X
,
d2_t
>::
value
)
{
return
data_
.
d2x8_
;
}
else
if
constexpr
(
is_same
<
X
,
d4_t
>::
value
)
{
return
data_
.
d4x4_
;
}
else
if
constexpr
(
is_same
<
X
,
d8_t
>::
value
)
{
return
data_
.
d8x2_
;
}
else
if
constexpr
(
is_same
<
X
,
d16_t
>::
value
)
{
return
data_
.
d16x1_
;
}
else
{
return
err
;
}
}
};
template
<
typename
T
>
struct
vector_type
<
T
,
32
,
typename
std
::
enable_if_t
<!
is_native_type
<
T
>
()
>>
{
using
d1_t
=
T
;
using
d2_t
=
non_native_vector_base
<
T
,
2
>
;
using
d4_t
=
non_native_vector_base
<
T
,
4
>
;
using
d8_t
=
non_native_vector_base
<
T
,
8
>
;
using
d16_t
=
non_native_vector_base
<
T
,
16
>
;
using
d32_t
=
non_native_vector_base
<
T
,
32
>
;
using
type
=
d32_t
;
union
alignas
(
next_pow2
(
32
*
sizeof
(
T
)))
{
d32_t
d32_
;
StaticallyIndexedArray
<
d1_t
,
32
>
d1x32_
;
StaticallyIndexedArray
<
d2_t
,
16
>
d2x16_
;
StaticallyIndexedArray
<
d4_t
,
8
>
d4x8_
;
StaticallyIndexedArray
<
d8_t
,
4
>
d8x4_
;
StaticallyIndexedArray
<
d16_t
,
2
>
d16x2_
;
StaticallyIndexedArray
<
d32_t
,
1
>
d32x1_
;
}
data_
;
__host__
__device__
constexpr
vector_type
()
:
data_
{
type
{}}
{}
__host__
__device__
constexpr
vector_type
(
type
v
)
:
data_
{
v
}
{}
template
<
typename
X
>
__host__
__device__
constexpr
const
auto
&
AsType
()
const
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
||
is_same
<
X
,
d16_t
>::
value
||
is_same
<
X
,
d32_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x32_
;
}
else
if
constexpr
(
is_same
<
X
,
d2_t
>::
value
)
{
return
data_
.
d2x16_
;
}
else
if
constexpr
(
is_same
<
X
,
d4_t
>::
value
)
{
return
data_
.
d4x8_
;
}
else
if
constexpr
(
is_same
<
X
,
d8_t
>::
value
)
{
return
data_
.
d8x4_
;
}
else
if
constexpr
(
is_same
<
X
,
d16_t
>::
value
)
{
return
data_
.
d16x2_
;
}
else
if
constexpr
(
is_same
<
X
,
d32_t
>::
value
)
{
return
data_
.
d32x1_
;
}
else
{
return
err
;
}
}
template
<
typename
X
>
__host__
__device__
constexpr
auto
&
AsType
()
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
||
is_same
<
X
,
d16_t
>::
value
||
is_same
<
X
,
d32_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x32_
;
}
else
if
constexpr
(
is_same
<
X
,
d2_t
>::
value
)
{
return
data_
.
d2x16_
;
}
else
if
constexpr
(
is_same
<
X
,
d4_t
>::
value
)
{
return
data_
.
d4x8_
;
}
else
if
constexpr
(
is_same
<
X
,
d8_t
>::
value
)
{
return
data_
.
d8x4_
;
}
else
if
constexpr
(
is_same
<
X
,
d16_t
>::
value
)
{
return
data_
.
d16x2_
;
}
else
if
constexpr
(
is_same
<
X
,
d32_t
>::
value
)
{
return
data_
.
d32x1_
;
}
else
{
return
err
;
}
}
};
template
<
typename
T
>
struct
vector_type
<
T
,
64
,
typename
std
::
enable_if_t
<!
is_native_type
<
T
>
()
>>
{
using
d1_t
=
T
;
using
d2_t
=
non_native_vector_base
<
T
,
2
>
;
using
d4_t
=
non_native_vector_base
<
T
,
4
>
;
using
d8_t
=
non_native_vector_base
<
T
,
8
>
;
using
d16_t
=
non_native_vector_base
<
T
,
16
>
;
using
d32_t
=
non_native_vector_base
<
T
,
32
>
;
using
d64_t
=
non_native_vector_base
<
T
,
64
>
;
using
type
=
d64_t
;
union
alignas
(
next_pow2
(
64
*
sizeof
(
T
)))
{
d64_t
d64_
;
StaticallyIndexedArray
<
d1_t
,
64
>
d1x64_
;
StaticallyIndexedArray
<
d2_t
,
32
>
d2x32_
;
StaticallyIndexedArray
<
d4_t
,
16
>
d4x16_
;
StaticallyIndexedArray
<
d8_t
,
8
>
d8x8_
;
StaticallyIndexedArray
<
d16_t
,
4
>
d16x4_
;
StaticallyIndexedArray
<
d32_t
,
2
>
d32x2_
;
StaticallyIndexedArray
<
d64_t
,
1
>
d64x1_
;
}
data_
;
__host__
__device__
constexpr
vector_type
()
:
data_
{
type
{}}
{}
__host__
__device__
constexpr
vector_type
(
type
v
)
:
data_
{
v
}
{}
template
<
typename
X
>
__host__
__device__
constexpr
const
auto
&
AsType
()
const
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
||
is_same
<
X
,
d16_t
>::
value
||
is_same
<
X
,
d32_t
>::
value
||
is_same
<
X
,
d64_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x64_
;
}
else
if
constexpr
(
is_same
<
X
,
d2_t
>::
value
)
{
return
data_
.
d2x32_
;
}
else
if
constexpr
(
is_same
<
X
,
d4_t
>::
value
)
{
return
data_
.
d4x16_
;
}
else
if
constexpr
(
is_same
<
X
,
d8_t
>::
value
)
{
return
data_
.
d8x8_
;
}
else
if
constexpr
(
is_same
<
X
,
d16_t
>::
value
)
{
return
data_
.
d16x4_
;
}
else
if
constexpr
(
is_same
<
X
,
d32_t
>::
value
)
{
return
data_
.
d32x2_
;
}
else
if
constexpr
(
is_same
<
X
,
d64_t
>::
value
)
{
return
data_
.
d64x1_
;
}
else
{
return
err
;
}
}
template
<
typename
X
>
__host__
__device__
constexpr
auto
&
AsType
()
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
||
is_same
<
X
,
d16_t
>::
value
||
is_same
<
X
,
d32_t
>::
value
||
is_same
<
X
,
d64_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x64_
;
}
else
if
constexpr
(
is_same
<
X
,
d2_t
>::
value
)
{
return
data_
.
d2x32_
;
}
else
if
constexpr
(
is_same
<
X
,
d4_t
>::
value
)
{
return
data_
.
d4x16_
;
}
else
if
constexpr
(
is_same
<
X
,
d8_t
>::
value
)
{
return
data_
.
d8x8_
;
}
else
if
constexpr
(
is_same
<
X
,
d16_t
>::
value
)
{
return
data_
.
d16x4_
;
}
else
if
constexpr
(
is_same
<
X
,
d32_t
>::
value
)
{
return
data_
.
d32x2_
;
}
else
if
constexpr
(
is_same
<
X
,
d64_t
>::
value
)
{
return
data_
.
d64x1_
;
}
else
{
return
err
;
}
}
};
using
int64_t
=
long
;
// fp64
using
double2_t
=
typename
vector_type
<
double
,
2
>::
type
;
using
double4_t
=
typename
vector_type
<
double
,
4
>::
type
;
// fp32
using
float2_t
=
typename
vector_type
<
float
,
2
>::
type
;
using
float4_t
=
typename
vector_type
<
float
,
4
>::
type
;
using
float8_t
=
typename
vector_type
<
float
,
8
>::
type
;
using
float16_t
=
typename
vector_type
<
float
,
16
>::
type
;
using
float32_t
=
typename
vector_type
<
float
,
32
>::
type
;
using
float64_t
=
typename
vector_type
<
float
,
64
>::
type
;
// fp16
using
half2_t
=
typename
vector_type
<
half_t
,
2
>::
type
;
using
half4_t
=
typename
vector_type
<
half_t
,
4
>::
type
;
using
half8_t
=
typename
vector_type
<
half_t
,
8
>::
type
;
using
half16_t
=
typename
vector_type
<
half_t
,
16
>::
type
;
using
half32_t
=
typename
vector_type
<
half_t
,
32
>::
type
;
using
half64_t
=
typename
vector_type
<
half_t
,
64
>::
type
;
// bfp16
using
bhalf2_t
=
typename
vector_type
<
bhalf_t
,
2
>::
type
;
using
bhalf4_t
=
typename
vector_type
<
bhalf_t
,
4
>::
type
;
using
bhalf8_t
=
typename
vector_type
<
bhalf_t
,
8
>::
type
;
using
bhalf16_t
=
typename
vector_type
<
bhalf_t
,
16
>::
type
;
using
bhalf32_t
=
typename
vector_type
<
bhalf_t
,
32
>::
type
;
using
bhalf64_t
=
typename
vector_type
<
bhalf_t
,
64
>::
type
;
// i32
using
int32x2_t
=
typename
vector_type
<
int32_t
,
2
>::
type
;
using
int32x4_t
=
typename
vector_type
<
int32_t
,
4
>::
type
;
using
int32x8_t
=
typename
vector_type
<
int32_t
,
8
>::
type
;
using
int32x16_t
=
typename
vector_type
<
int32_t
,
16
>::
type
;
using
int32x32_t
=
typename
vector_type
<
int32_t
,
32
>::
type
;
using
int32x64_t
=
typename
vector_type
<
int32_t
,
64
>::
type
;
// i8
using
int8x2_t
=
typename
vector_type
<
int8_t
,
2
>::
type
;
using
int8x4_t
=
typename
vector_type
<
int8_t
,
4
>::
type
;
using
int8x8_t
=
typename
vector_type
<
int8_t
,
8
>::
type
;
using
int8x16_t
=
typename
vector_type
<
int8_t
,
16
>::
type
;
using
int8x32_t
=
typename
vector_type
<
int8_t
,
32
>::
type
;
using
int8x64_t
=
typename
vector_type
<
int8_t
,
64
>::
type
;
// f8
using
f8x2_t
=
typename
vector_type
<
f8_t
,
2
>::
type
;
using
f8x4_t
=
typename
vector_type
<
f8_t
,
4
>::
type
;
using
f8x8_t
=
typename
vector_type
<
f8_t
,
8
>::
type
;
using
f8x16_t
=
typename
vector_type
<
f8_t
,
16
>::
type
;
using
f8x32_t
=
typename
vector_type
<
f8_t
,
32
>::
type
;
using
f8x64_t
=
typename
vector_type
<
f8_t
,
64
>::
type
;
// bf8
using
bf8x2_t
=
typename
vector_type
<
bf8_t
,
2
>::
type
;
using
bf8x4_t
=
typename
vector_type
<
bf8_t
,
4
>::
type
;
using
bf8x8_t
=
typename
vector_type
<
bf8_t
,
8
>::
type
;
using
bf8x16_t
=
typename
vector_type
<
bf8_t
,
16
>::
type
;
using
bf8x32_t
=
typename
vector_type
<
bf8_t
,
32
>::
type
;
using
bf8x64_t
=
typename
vector_type
<
bf8_t
,
64
>::
type
;
// u8
using
uint8x2_t
=
typename
vector_type
<
uint8_t
,
2
>::
type
;
using
uint8x4_t
=
typename
vector_type
<
uint8_t
,
4
>::
type
;
using
uint8x8_t
=
typename
vector_type
<
uint8_t
,
8
>::
type
;
using
uint8x16_t
=
typename
vector_type
<
uint8_t
,
16
>::
type
;
using
uint8x32_t
=
typename
vector_type
<
uint8_t
,
32
>::
type
;
using
uint8x64_t
=
typename
vector_type
<
uint8_t
,
64
>::
type
;
template
<
typename
T
>
struct
NumericLimits
{
__host__
__device__
static
constexpr
T
Min
()
{
return
std
::
numeric_limits
<
T
>::
min
();
}
__host__
__device__
static
constexpr
T
Max
()
{
return
std
::
numeric_limits
<
T
>::
max
();
}
__host__
__device__
static
constexpr
T
Lowest
()
{
return
std
::
numeric_limits
<
T
>::
lowest
();
}
__host__
__device__
static
constexpr
T
QuietNaN
()
{
return
std
::
numeric_limits
<
T
>::
quiet_NaN
();
}
__host__
__device__
static
constexpr
T
Infinity
()
{
return
std
::
numeric_limits
<
T
>::
infinity
();
}
};
template
<
>
struct
NumericLimits
<
half_t
>
{
static
constexpr
unsigned
short
binary_min
=
0x0400
;
static
constexpr
unsigned
short
binary_max
=
0x7BFF
;
static
constexpr
unsigned
short
binary_lowest
=
0xFBFF
;
static
constexpr
unsigned
short
binary_qnan
=
0x7FFF
;
__host__
__device__
static
constexpr
half_t
Min
()
{
return
bit_cast
<
half_t
>
(
binary_min
);
}
__host__
__device__
static
constexpr
half_t
Max
()
{
return
bit_cast
<
half_t
>
(
binary_max
);
}
__host__
__device__
static
constexpr
half_t
Lowest
()
{
return
bit_cast
<
half_t
>
(
binary_lowest
);
}
__host__
__device__
static
constexpr
half_t
QuietNaN
()
{
return
bit_cast
<
half_t
>
(
binary_qnan
);
}
};
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template
<
>
struct
NumericLimits
<
int4_t
>
{
__host__
__device__
static
constexpr
int4_t
Min
()
{
return
int4_t
(
-
8
);
}
__host__
__device__
static
constexpr
int4_t
Max
()
{
return
int4_t
(
7
);
}
__host__
__device__
static
constexpr
int4_t
Lowest
()
{
return
int4_t
(
-
8
);
}
};
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template
<
>
struct
NumericLimits
<
f8_t
>
{
// negative zero nan mode with exp bias = 8
static
constexpr
uint8_t
binary_min
=
0x08
;
// 0b00001000
static
constexpr
uint8_t
binary_max
=
0x7F
;
// 0b01111111
static
constexpr
uint8_t
binary_lowest
=
0xFF
;
// 0b11111111
static
constexpr
uint8_t
binary_qnan
=
0x80
;
// 0b10000000
// ieee mode with exp bias = 7
// static constexpr uint8_t binary_min = 0x08; // 0b00001000
// static constexpr uint8_t binary_max = 0x77; // 0b01110111
// static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111
// static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!=0
__host__
__device__
static
constexpr
f8_t
Min
()
{
return
f8_t
(
binary_min
);
}
__host__
__device__
static
constexpr
f8_t
Max
()
{
return
f8_t
(
binary_max
);
}
__host__
__device__
static
constexpr
f8_t
Lowest
()
{
return
f8_t
(
binary_lowest
);
}
__host__
__device__
static
constexpr
f8_t
QuietNaN
()
{
return
f8_t
(
binary_qnan
);
}
};
template
<
>
struct
NumericLimits
<
bf8_t
>
{
// negative zero nan mode with exp bias = 16
static
constexpr
uint8_t
binary_min
=
0x04
;
// 0b00000100
static
constexpr
uint8_t
binary_max
=
0x7F
;
// 0b01111111
static
constexpr
uint8_t
binary_lowest
=
0xFF
;
// 0b11111111
static
constexpr
uint8_t
binary_qnan
=
0x80
;
// 0b10000000
// ieee mode with exp bias = 15
// static constexpr uint8_t binary_min = 0x04; // 0b00000100
// static constexpr uint8_t binary_max = 0x7B; // 0b01111011
// static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011
// static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!=
__host__
__device__
static
constexpr
bf8_t
Min
()
{
return
bf8_t
(
binary_min
);
}
__host__
__device__
static
constexpr
bf8_t
Max
()
{
return
bf8_t
(
binary_max
);
}
__host__
__device__
static
constexpr
bf8_t
Lowest
()
{
return
bf8_t
(
binary_lowest
);
}
__host__
__device__
static
constexpr
bf8_t
QuietNaN
()
{
return
bf8_t
(
binary_qnan
);
}
};
template
<
typename
T
>
struct
NumericUtils
{
};
template
<
>
struct
NumericUtils
<
float
>
{
static
constexpr
int
exp
=
8
;
static
constexpr
int
mant
=
23
;
static
constexpr
int
bias
=
127
;
static
constexpr
uint32_t
nan_mask
=
0x7F800000
;
...
...
@@ -1208,4 +1803,13 @@ struct NumericUtils<bf8_t>
    static constexpr int bias = 16; // negative zero nan mode
    // static constexpr int bias = 15; // ieee mode
};

template <>
struct NumericUtils<bhalf_t>
{
    static constexpr int exp  = 8;
    static constexpr int mant = 7;
    static constexpr int bias = 128; // negative zero nan mode
    // static constexpr int bias = 127; // ieee mode
};
} // namespace ck
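NumericUtils<bhalf_t> records the bfloat16 layout: one sign bit, exp = 8 exponent bits, and mant = 7 mantissa bits; the bias of 128 reflects the library's "negative zero nan" mode, whereas conventional IEEE-style bfloat16 uses 127. A small host-side sketch that decodes a bit pattern by hand with the conventional bias of 127, just to make the field widths concrete (this is an illustration, not CK code):

// Decode a bfloat16 bit pattern by hand using the exp/mant field widths above.
// Uses the conventional IEEE-style bias of 127 (CK's "negative zero nan" mode shifts it to 128).
#include <cmath>
#include <cstdint>
#include <cstdio>

float decode_bf16(std::uint16_t bits)
{
    const int sign     = (bits >> 15) & 0x1;
    const int exponent = (bits >> 7) & 0xFF; // 8 exponent bits
    const int mantissa = bits & 0x7F;        // 7 mantissa bits

    if(exponent == 0) // subnormal (or zero)
        return (sign ? -1.f : 1.f) * std::ldexp(static_cast<float>(mantissa), -7 - 126);

    const float frac = 1.f + static_cast<float>(mantissa) / 128.f;
    return (sign ? -1.f : 1.f) * std::ldexp(frac, exponent - 127);
}

int main()
{
    std::printf("0x3F80 -> %f (expect 1.0)\n", decode_bf16(0x3F80));
    std::printf("0x4049 -> %f (~pi truncated to bf16)\n", decode_bf16(0x4049));
    std::printf("0xC000 -> %f (expect -2.0)\n", decode_bf16(0xC000));
    return 0;
}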
include/ck/utility/math_v2.hpp
...
...
@@ -80,6 +80,8 @@ static inline __host__ bool isnan(half_t x)
    return (xx & 0x7FFF) > 0x7C00;
};

static inline __host__ bool isnan(f8_t x) { return (x & 0x80); };

#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
static inline __host__ bool isnan(int4_t x)
{
...
...
@@ -529,6 +531,8 @@ static inline __device__ bool isnan(half_t x)
    return (xx & 0x7FFF) > 0x7C00;
};

static inline __device__ bool isnan(f8_t x) { return (x & 0x80); };

static inline __device__ half_t sqrt(half_t x)
{
    return static_cast<half_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x)));
...
...
@@ -649,7 +653,7 @@ inline __device__ double sin<double>(double x)
template <>
inline __device__ half_t sin<half_t>(half_t x)
{
-    return ::hsin(x);
    return hsin(static_cast<__half>(x));
};

template <typename T>
...
...
@@ -781,7 +785,7 @@ inline __device__ double ceil<double>(double x)
template <>
inline __device__ half_t ceil<half_t>(half_t x)
{
-    return ::hceil(x);
    return hceil(static_cast<__half>(x));
};

template <typename T>
...
...
@@ -823,7 +827,7 @@ inline __device__ double floor<double>(double x)
template <>
inline __device__ half_t floor<half_t>(half_t x)
{
-    return ::hfloor(x);
    return hfloor(static_cast<__half>(x));
};

template <typename T>
...
...
@@ -845,7 +849,7 @@ inline __device__ T exp(T x)
template <>
inline __device__ half_t exp<half_t>(half_t x)
{
-    return hexp(x);
    return hexp(static_cast<__half>(x));
};

template <>
...
...
@@ -869,7 +873,7 @@ inline __device__ T log(T x)
template <>
inline __device__ half_t log<half_t>(half_t x)
{
-    return hlog(x);
    return hlog(static_cast<__half>(x));
};

template <>
...
...
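The device-side half_t math wrappers now cast to __half explicitly before calling the hsin/hceil/hfloor/hexp/hlog intrinsics, since half_t (a _Float16 typedef in CK) is not the __half type those intrinsics take. A minimal HIP-style sketch of the same pattern; the kernel is hypothetical, not code from this file, and assumes a ROCm toolchain that provides hip_fp16.h:

// Sketch of the explicit _Float16 -> __half cast pattern used by the wrappers above.
#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>

using half_t = _Float16;

__device__ half_t exp_half(half_t x)
{
    // hexp takes a __half, so the _Float16 argument is converted explicitly.
    return hexp(static_cast<__half>(x));
}

__global__ void exp_kernel(const half_t* in, half_t* out, int n)
{
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if(i < n)
        out[i] = exp_half(in[i]);
}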
include/ck_tile/core.hpp
...
...
@@ -5,6 +5,7 @@
#include "ck_tile/core/algorithm/cluster_descriptor.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/algorithm/indexing_adaptor.hpp"
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
#include "ck_tile/core/arch/amd_buffer_addressing.hpp"
#include "ck_tile/core/arch/arch.hpp"
...
...
@@ -24,6 +25,7 @@
#include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/numeric/float8.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/int8.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/math.hpp"
...
...
@@ -49,13 +51,17 @@
#include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
#include "ck_tile/core/tensor/tile_elementwise.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/tensor/tile_window_linear.hpp"
#include "ck_tile/core/tensor/update_tile.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/functional_with_tuple.hpp"
#include "ck_tile/core/utility/ignore.hpp"
#include "ck_tile/core/utility/literals.hpp"
#include "ck_tile/core/utility/magic_div.hpp"
#include "ck_tile/core/utility/philox_rand.hpp"
#include "ck_tile/core/utility/random.hpp"
#include "ck_tile/core/utility/reduce_operator.hpp"
#include "ck_tile/core/utility/to_sequence.hpp"
#include "ck_tile/core/utility/transpose_vectors.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
...
...