gaoqiong / composable_kernel_ROCM · Commits

Commit d43cd4ad
Authored Sep 25, 2024 by Mirza Halilcevic
Parent: 3528a523

    Introduce gemm_softmax_gemm to codegen.

Changes: 52 files in total; this page shows 20 changed files with 654 additions and 76 deletions (+654 −76).
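The fused operator computes, per batch, C = Softmax(scale · (A × B0)) × B1, optionally masking the upper triangle of the first product (a causal mask, per the MaskOutUpperTriangle parameter). Below is a minimal host-side reference sketch of that computation, assuming row-major operands; the names and loop structure are illustrative, not code from this commit.

#include <cmath>
#include <vector>

// Reference: C (M x O) = Softmax(scale * A (M x K) x B0 (K x N)) x B1 (N x O)
void gemm_softmax_gemm_ref(const std::vector<float>& A,
                           const std::vector<float>& B0,
                           const std::vector<float>& B1,
                           std::vector<float>& C,
                           int M, int N, int K, int O,
                           float scale, bool mask_upper_triangle)
{
    std::vector<float> S(static_cast<size_t>(M) * N);
    for(int m = 0; m < M; ++m)
    {
        // First GEMM, scaled, with optional causal masking.
        for(int n = 0; n < N; ++n)
        {
            float acc = 0.f;
            for(int k = 0; k < K; ++k)
                acc += A[m * K + k] * B0[k * N + n];
            S[m * N + n] = (mask_upper_triangle && n > m) ? -INFINITY : scale * acc;
        }
        // Row-wise softmax (max-subtracted for numerical stability).
        float row_max = S[m * N];
        for(int n = 1; n < N; ++n)
            row_max = std::fmax(row_max, S[m * N + n]);
        float row_sum = 0.f;
        for(int n = 0; n < N; ++n)
            row_sum += (S[m * N + n] = std::exp(S[m * N + n] - row_max));
        // Second GEMM against the softmax-normalized rows.
        for(int o = 0; o < O; ++o)
        {
            float acc = 0.f;
            for(int n = 0; n < N; ++n)
                acc += (S[m * N + n] / row_sum) * B1[n * O + o];
            C[m * O + o] = acc;
        }
    }
}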
include/ck/tensor_operation/gpu/device/device_base.hpp                                         +6   −2
include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp                    +4   −1
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp                              +7   −1
include/ck/tensor_operation/gpu/device/gemm_specialization.hpp                                 +2   −1
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp  +381 −23
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp            +34  −23
include/ck/tensor_operation/gpu/device/masking_specialization.hpp                              +3   −1
include/ck/tensor_operation/gpu/device/tensor_layout.hpp                                       +2   −0
include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp                      +2   −2
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp                       +2   −2
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp                                    +4   −2
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp                 +4   −4
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp                       +6   −0
include/ck/utility/amd_buffer_addressing.hpp                                                   +2   −0
include/ck/utility/amd_wave_read_first_lane.hpp                                                +12  −10
include/ck/utility/array.hpp                                                                   +3   −1
include/ck/utility/container_helper.hpp                                                        +2   −2
include/ck/utility/data_type.hpp                                                               +158 −0
include/ck/utility/enable_if.hpp                                                               +16  −1
include/ck/utility/env.hpp                                                                     +4   −0
include/ck/tensor_operation/gpu/device/device_base.hpp

@@ -3,15 +3,17 @@
#pragma once

#ifndef __HIPCC_RTC__
#include <string>
#include <sstream>

#include "ck/stream_config.hpp"
#endif

namespace ck {
namespace tensor_operation {
namespace device {

#ifndef __HIPCC_RTC__
struct BaseArgument
{
    BaseArgument() = default;
@@ -36,6 +38,7 @@ struct BaseInvoker
    virtual ~BaseInvoker() {}
};
#endif

struct BaseOperator
{
@@ -43,6 +46,7 @@ struct BaseOperator
    BaseOperator(const BaseOperator&) = default;
    BaseOperator& operator=(const BaseOperator&) = default;

#ifndef __HIPCC_RTC__
    virtual bool IsSupportedArgument(const BaseArgument*) { return false; }

    virtual std::string GetTypeString() const { return ""; }
@@ -66,7 +70,7 @@ struct BaseOperator
        assert(p_arg);
        p_arg->p_workspace_ = p_workspace;
    }
#endif

    virtual ~BaseOperator() {}
};
include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp

@@ -2,9 +2,10 @@
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#ifndef __HIPCC_RTC__
#include <iostream>
#include <vector>
#endif

#include "device_base.hpp"
@@ -28,6 +29,7 @@ template <typename ALayout,
          bool MaskOutUpperTriangle> // TODO: enum for mask type
struct DeviceBatchedGemmSoftmaxGemm : public BaseOperator
{
#ifndef __HIPCC_RTC__
    virtual std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const void* p_a,
                        const void* p_b0,
@@ -53,6 +55,7 @@ struct DeviceBatchedGemmSoftmaxGemm : public BaseOperator
                        CElementwiseOperation c_element_op) = 0;

    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
#endif
};

} // namespace device
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp

@@ -2,9 +2,11 @@
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#ifndef __HIPCC_RTC__
#include <array>
#endif

#include "ck/utility/array.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"

namespace ck {
@@ -34,6 +36,7 @@ struct DeviceGemmMultipleD : public BaseOperator
{
    static constexpr index_t NumDTensor = DsDataType::Size();

#ifndef __HIPCC_RTC__
    virtual std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const void* p_a,
                        const void* p_b,
@@ -51,6 +54,7 @@ struct DeviceGemmMultipleD : public BaseOperator
                        CDEElementwiseOperation cde_element_op) = 0;

    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
#endif
};

// GEMM:
@@ -76,6 +80,7 @@ struct DeviceGemmMultipleDSplitK : public BaseOperator
{
    static constexpr index_t NumDTensor = DsDataType::Size();

#ifndef __HIPCC_RTC__
    virtual std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const void* p_a,
                        const void* p_b,
@@ -94,6 +99,7 @@ struct DeviceGemmMultipleDSplitK : public BaseOperator
                        CDEElementwiseOperation cde_element_op) = 0;

    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
#endif
};

} // namespace device
include/ck/tensor_operation/gpu/device/gemm_specialization.hpp

@@ -28,7 +28,7 @@ enum struct GemmSpecialization
    NKOPadding,
    MNKOPadding,
};

#ifndef __HIPCC_RTC__
inline std::string getGemmSpecializationString(const GemmSpecialization& s)
{
    switch(s)
@@ -52,6 +52,7 @@ inline std::string getGemmSpecializationString(const GemmSpecialization& s)
    default: return "Unrecognized specialization!";
    }
}
#endif

} // namespace device
} // namespace tensor_operation
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp

@@ -3,8 +3,12 @@
#pragma once

#ifndef __HIPCC_RTC__
#include <iostream>
#include <sstream>

#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#endif

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
@@ -15,8 +19,6 @@
 #include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp"
-#include "ck/host_utility/device_prop.hpp"
-#include "ck/host_utility/kernel_launch.hpp"

namespace ck {
namespace tensor_operation {
@@ -40,7 +42,7 @@ template <typename GridwiseGemm,
          bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1(const FloatAB* __restrict__ p_a_grid,
@@ -430,6 +432,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
                                                         matrix_padder.PadN,
                                                         MaskOutUpperTriangle>;

#ifndef __HIPCC_RTC__
    // Argument
    struct Argument : public BaseArgument
    {
@@ -604,6 +607,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };
#endif

    static constexpr bool IsValidCompilationParameter()
    {
@@ -611,6 +615,97 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
        return true;
    }

    static constexpr bool
    IsSupported(index_t MRaw_, index_t NRaw_, index_t KRaw_, index_t Gemm1NRaw_)
    {
        // check vector load/store
        using Row = ck::tensor_layout::gemm::RowMajor;
        using Col = ck::tensor_layout::gemm::ColumnMajor;

        // check vector load of A
        if constexpr(is_same_v<ALayout, Row>)
        {
            if(KRaw_ % ABlockTransferSrcScalarPerVector != 0)
            {
                return false;
            }
        }
        else if constexpr(is_same_v<ALayout, Col>)
        {
            if(MRaw_ % ABlockTransferSrcScalarPerVector != 0)
            {
                return false;
            }
        }
        else
        {
            return false;
        }

        // check vector load of B
        if constexpr(is_same_v<BLayout, Row>)
        {
            if(NRaw_ % BBlockTransferSrcScalarPerVector != 0)
            {
                return false;
            }
        }
        else if constexpr(is_same_v<BLayout, Col>)
        {
            if(KRaw_ % BBlockTransferSrcScalarPerVector != 0)
            {
                return false;
            }
        }
        else
        {
            return false;
        }

        // check vector load of B1
        if constexpr(is_same_v<B1Layout, Row>)
        {
            if(Gemm1NRaw_ % B1BlockTransferSrcScalarPerVector != 0)
            {
                return false;
            }
        }
        else if constexpr(is_same_v<B1Layout, Col>)
        {
            if(NRaw_ % B1BlockTransferSrcScalarPerVector != 0)
            {
                return false;
            }
        }
        else
        {
            return false;
        }

        // check vector load of C
        if constexpr(is_same_v<CLayout, Row>)
        {
            if(Gemm1NRaw_ % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
            {
                return false;
            }
        }
        else if constexpr(is_same_v<CLayout, Col>)
        {
            if(MRaw_ % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
            {
                return false;
            }
        }
        else
        {
            return false;
        }

        return true;
    }
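In words: each operand can only be read (or written) with vectorized memory accesses when its contiguous dimension is a multiple of the configured vector width. A standalone sketch of the rule (can_vector_load is a hypothetical helper, not part of the diff):

// A width-V vector access along a tensor's contiguous dimension only lines up
// when that dimension is a multiple of V; otherwise the last vector would
// straddle a row boundary, so IsSupported() rejects the shape.
constexpr bool can_vector_load(int contiguous_extent, int scalar_per_vector)
{
    return contiguous_extent % scalar_per_vector == 0;
}
static_assert(can_vector_load(128, 8));  // e.g. K = 128 with 8-wide loads: OK
static_assert(!can_vector_load(100, 8)); // K = 100: rejected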
#ifndef __HIPCC_RTC__
    static bool IsSupportedArgument(const Argument& arg)
    {
        if(!ck::is_xdl_supported())
@@ -765,6 +860,269 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
        return str.str();
    }
#endif
    template <class ADesc, class BDesc, class B1Desc, class CDesc>
    struct Descriptor
    {
        template <class AGridDescriptor>
        static constexpr auto MakeAGridDescriptor_AK0_M_AK1(const AGridDescriptor& a_grid_desc)
        {
            const auto a_grid_desc_m_k = DeviceOp::matrix_padder.PadADescriptor_M_K(a_grid_desc);
            const auto M   = a_grid_desc_m_k.GetLength(I0);
            const auto K   = a_grid_desc_m_k.GetLength(I1);
            const auto AK0 = K / AK1;

            return transform_tensor_descriptor(
                a_grid_desc_m_k,
                make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
                           make_pass_through_transform(M)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        }

        template <class BGridDescriptor>
        static constexpr auto MakeBGridDescriptor_BK0_N_BK1(const BGridDescriptor& b_grid_desc)
        {
            const auto b_grid_desc_n_k = DeviceOp::matrix_padder.PadBDescriptor_N_K(b_grid_desc);
            const auto N   = b_grid_desc_n_k.GetLength(I0);
            const auto K   = b_grid_desc_n_k.GetLength(I1);
            const auto BK0 = K / BK1;

            return transform_tensor_descriptor(
                b_grid_desc_n_k,
                make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
                           make_pass_through_transform(N)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        }

        template <class B1GridDescriptor>
        static constexpr auto MakeB1GridDescriptor_BK0_N_BK1(const B1GridDescriptor& b1_grid_desc)
        {
            const auto b1_grid_desc_n_k = DeviceOp::matrix_padder.PadB1Descriptor_N_K(b1_grid_desc);
            const auto N    = b1_grid_desc_n_k.GetLength(I0);
            const auto K    = b1_grid_desc_n_k.GetLength(I1);
            const auto B1K0 = K / B1K1;

            return transform_tensor_descriptor(
                b1_grid_desc_n_k,
                make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
                           make_pass_through_transform(N)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        }

        template <class CGridDescriptor>
        static constexpr auto MakeCGridDescriptor_M_N(const CGridDescriptor& c_grid_desc)
        {
            return DeviceOp::matrix_padder.PadCDescriptor_M_N(c_grid_desc);
        }

        using AGridDesc_AK0_M_AK1  = remove_cvref_t<decltype(MakeAGridDescriptor_AK0_M_AK1(ADesc{}))>;
        using BGridDesc_BK0_N_BK1  = remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1(BDesc{}))>;
        using B1GridDesc_BK0_N_BK1 = remove_cvref_t<decltype(MakeB1GridDescriptor_BK0_N_BK1(B1Desc{}))>;
        using CGridDesc_M_N        = remove_cvref_t<decltype(MakeCGridDescriptor_M_N(CDesc{}))>;

        // GridwiseGemm
        using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle<
            ADataType, // TODO: distinguish A/B datatype
            GemmAccDataType,
            CShuffleDataType,
            CDataType,
            AElementwiseOperation,
            BElementwiseOperation,
            AccElementwiseOperation,
            B1ElementwiseOperation,
            CElementwiseOperation,
            InMemoryDataOperationEnum::Set,
            AGridDesc_AK0_M_AK1,
            BGridDesc_BK0_N_BK1,
            B1GridDesc_BK0_N_BK1,
            CGridDesc_M_N,
            NumGemmKPrefetchStage,
            BlockSize,
            MPerBlock,
            NPerBlock,
            KPerBlock,
            Gemm1NPerBlock,
            Gemm1KPerBlock,
            AK1,
            BK1,
            B1K1,
            MPerXDL,
            NPerXDL,
            MXdlPerWave,
            NXdlPerWave,
            Gemm1NXdlPerWave,
            ABlockTransferThreadClusterLengths_AK0_M_AK1,
            ABlockTransferThreadClusterArrangeOrder,
            ABlockTransferSrcAccessOrder,
            ABlockTransferSrcVectorDim,
            ABlockTransferSrcScalarPerVector,
            ABlockTransferDstScalarPerVector_AK1,
            true,
            ABlockLdsExtraM,
            BBlockTransferThreadClusterLengths_BK0_N_BK1,
            BBlockTransferThreadClusterArrangeOrder,
            BBlockTransferSrcAccessOrder,
            BBlockTransferSrcVectorDim,
            BBlockTransferSrcScalarPerVector,
            BBlockTransferDstScalarPerVector_BK1,
            true,
            BBlockLdsExtraN,
            B1BlockTransferThreadClusterLengths_BK0_N_BK1,
            B1BlockTransferThreadClusterArrangeOrder,
            B1BlockTransferSrcAccessOrder,
            B1BlockTransferSrcVectorDim,
            B1BlockTransferSrcScalarPerVector,
            B1BlockTransferDstScalarPerVector_BK1,
            false,
            B1BlockLdsExtraN,
            CShuffleMXdlPerWavePerShuffle,
            CShuffleNXdlPerWavePerShuffle,
            CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
            CShuffleBlockTransferScalarPerVector_NPerBlock,
            LoopSched,
            matrix_padder.PadN,
            MaskOutUpperTriangle>;

        AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1;
        BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1;
        B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1;
        CGridDesc_M_N c_grid_desc_m_n;
        C0MatrixMask c0_matrix_mask;
        typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map;
        typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
            c_grid_descriptor_mblock_mperblock_nblock_nperblock;

        // element-wise op
        AElementwiseOperation a_element_op;
        BElementwiseOperation b_element_op;
        B1ElementwiseOperation b1_element_op;
        CElementwiseOperation c_element_op;

        bool has_main_k_block_loop = true;
        bool is_valid              = false;

        constexpr Descriptor(ADesc a,
                             BDesc b,
                             B1Desc b1,
                             CDesc c,
                             AElementwiseOperation a_element_op_,
                             BElementwiseOperation b_element_op_,
                             B1ElementwiseOperation b1_element_op_,
                             CElementwiseOperation c_element_op_)
            : a_grid_desc_ak0_m_ak1{MakeAGridDescriptor_AK0_M_AK1(a)},
              b_grid_desc_bk0_n_bk1{MakeBGridDescriptor_BK0_N_BK1(b)},
              b1_grid_desc_bk0_n_bk1{MakeB1GridDescriptor_BK0_N_BK1(b1)},
              c_grid_desc_m_n{MakeCGridDescriptor_M_N(c)},
              block_2_ctile_map{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n)},
              c_grid_descriptor_mblock_mperblock_nblock_nperblock{
                  GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                      c_grid_desc_m_n)},
              has_main_k_block_loop{GridwiseGemm::CalculateHasMainKBlockLoop(
                  a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2))},
              c0_matrix_mask{c.GetLength(I1)},
              a_element_op{a_element_op_},
              b_element_op{b_element_op_},
              b1_element_op{b1_element_op_},
              c_element_op{c_element_op_},
              is_valid{GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1,
                                                   b_grid_desc_bk0_n_bk1,
                                                   b1_grid_desc_bk0_n_bk1,
                                                   c_grid_desc_m_n,
                                                   block_2_ctile_map) and
                       IsSupported(a_grid_desc_ak0_m_ak1.GetLength(I1),
                                   b_grid_desc_bk0_n_bk1.GetLength(I1),
                                   a_grid_desc_ak0_m_ak1.GetLength(I0) *
                                       a_grid_desc_ak0_m_ak1.GetLength(I2),
                                   b1_grid_desc_bk0_n_bk1.GetLength(I1))}
        {
        }

        constexpr bool IsValid() const { return is_valid; }
    };

    template <class ADesc, class BDesc, class B1Desc, class CDesc>
    static constexpr auto
    make_descriptor(ADesc a,
                    BDesc b,
                    B1Desc b1,
                    CDesc c,
                    AElementwiseOperation a_element_op   = AElementwiseOperation{},
                    BElementwiseOperation b_element_op   = BElementwiseOperation{},
                    B1ElementwiseOperation b1_element_op = B1ElementwiseOperation{},
                    CElementwiseOperation c_element_op   = CElementwiseOperation{})
    {
        return Descriptor<ADesc, BDesc, B1Desc, CDesc>(
            a, b, b1, c, a_element_op, b_element_op, b1_element_op, c_element_op);
    }

    template <class Desc>
    __device__ static void Run(const Desc& desc,
                               const float scale,
                               const ADataType* __restrict__ p_a_grid,
                               const ADataType* __restrict__ p_b_grid,
                               const ADataType* __restrict__ p_b1_grid,
                               CDataType* __restrict__ p_c_grid)
    {
#ifndef __HIPCC_RTC__
        assert(desc.is_valid);
#endif
        __shared__ char p_shared_block[Desc::GridwiseGemm::GetSharedMemoryNumberOfByte()];

        AccElementwiseOperation acc_element_op{scale};

        if(desc.has_main_k_block_loop)
        {
            Desc::GridwiseGemm::template Run<true>(
                p_a_grid,
                p_b_grid,
                p_b1_grid,
                p_c_grid,
                p_shared_block,
                desc.a_element_op,
                desc.b_element_op,
                acc_element_op,
                desc.b1_element_op,
                desc.c_element_op,
                desc.a_grid_desc_ak0_m_ak1,
                desc.b_grid_desc_bk0_n_bk1,
                desc.b1_grid_desc_bk0_n_bk1,
                desc.c_grid_descriptor_mblock_mperblock_nblock_nperblock,
                desc.block_2_ctile_map,
                desc.c0_matrix_mask);
        }
        else
        {
            Desc::GridwiseGemm::template Run<false>(
                p_a_grid,
                p_b_grid,
                p_b1_grid,
                p_c_grid,
                p_shared_block,
                desc.a_element_op,
                desc.b_element_op,
                acc_element_op,
                desc.b1_element_op,
                desc.c_element_op,
                desc.a_grid_desc_ak0_m_ak1,
                desc.b_grid_desc_bk0_n_bk1,
                desc.b1_grid_desc_bk0_n_bk1,
                desc.c_grid_descriptor_mblock_mperblock_nblock_nperblock,
                desc.block_2_ctile_map,
                desc.c0_matrix_mask);
        }
    }
};

} // namespace device
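Since descriptor construction is now constexpr, a JIT-compiled translation unit can build the whole Descriptor in device code and call the static Run(). A hypothetical codegen-side usage sketch (not part of the commit; Op stands for a fully instantiated DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle, and ADesc/BDesc/B1Desc/CDesc for compile-time tensor descriptors):

__global__ void fused_gemm_softmax_gemm(const ADataType* a,
                                        const ADataType* b,
                                        const ADataType* b1,
                                        CDataType* c,
                                        float scale)
{
    // Built entirely at compile time; IsValid() folds the CheckValidity()
    // and IsSupported() results into one constexpr flag.
    constexpr auto desc = Op::make_descriptor(ADesc{}, BDesc{}, B1Desc{}, CDesc{});
    static_assert(desc.IsValid(), "unsupported shape/layout for this instance");

    Op::Run(desc, scale, a, b, b1, c);
}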
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp

@@ -3,8 +3,12 @@
#pragma once

#ifndef __HIPCC_RTC__
#include <iostream>
#include <sstream>

#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#endif

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
@@ -14,8 +18,6 @@
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
-#include "ck/host_utility/device_prop.hpp"
-#include "ck/host_utility/kernel_launch.hpp"

namespace ck {
@@ -35,7 +37,7 @@ template <typename GridwiseGemm,
          bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
    kernel_gemm_multiple_d_xdl_cshuffle(const ADataType* __restrict__ p_a_grid,
                                        const BDataType* __restrict__ p_b_grid,
@@ -225,9 +227,9 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
        return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
    }

-   static auto MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws,
-                                        const std::array<index_t, NumDTensor>& NRaws,
-                                        const std::array<index_t, NumDTensor>& DsStride)
+   static auto MakeDsGridDescriptor_M_N(const Array<index_t, NumDTensor>& MRaws,
+                                        const Array<index_t, NumDTensor>& NRaws,
+                                        const Array<index_t, NumDTensor>& DsStride)
    {
        return generate_tuple(
            [&](auto i) {
@@ -309,6 +311,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
    using Block2ETileMap =
        remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;

#ifndef __HIPCC_RTC__
    // Argument
    struct Argument : public BaseArgument
    {
@@ -498,6 +501,8 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
        }
    };
#endif

    static constexpr bool IsSupported(index_t MRaw_, index_t NRaw_, index_t KRaw_)
    {
        // check vector load/store
@@ -578,6 +583,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
        return true;
    }

#ifndef __HIPCC_RTC__
    static bool IsSupportedArgument(const Argument& arg)
    {
        if(!ck::is_xdl_supported())
@@ -676,11 +682,13 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
    {
        auto str = std::stringstream();

        std::map<LoopScheduler, std::string> LoopSchedToString{
            {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};

        std::map<PipelineVersion, std::string> PipelineVersionToString{
            {PipelineVersion::v1, "v1"}, {PipelineVersion::v2, "v2"}};

        // clang-format off
        str << "DeviceGemmMultipleD_Xdl_CShuffle"
@@ -709,6 +717,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
        return str.str();
    }
#endif

    template <class ADesc, class BDesc, class DsDesc, class EDesc>
    struct Descriptor
@@ -847,7 +856,9 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
                               EDataType* __restrict__ p_e_grid)
    {
        __shared__ char p_shared_block[GridwiseGemm::GetSharedMemoryNumberOfByte()];
#ifndef __HIPCC_RTC__
        assert(desc.IsValid());
#endif
        if(desc.has_main_k_block_loop)
        {
            GridwiseGemm::template Run<true>(p_a_grid,
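The caller-side effect of the std::array → ck::Array switch is that one signature now works from both host code and HIPRTC-compiled code, where <array> is unavailable. A small sketch (make_ds_strides and the stride values are made up for illustration):

#include "ck/utility/array.hpp"

// Strides for two D tensors, usable on host and under RTC alike.
inline auto make_ds_strides(ck::index_t ld_d0, ck::index_t ld_d1)
{
    return ck::make_array(ld_d0, ld_d1); // ck::Array<ck::index_t, 2>
}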
include/ck/tensor_operation/gpu/device/masking_specialization.hpp

@@ -13,6 +13,7 @@ enum struct MaskingSpecialization
    MaskOutUpperTriangle
};

#ifndef __HIPCC_RTC__
inline std::string getMaskingSpecializationString(const MaskingSpecialization& s)
{
    switch(s)
@@ -22,6 +23,7 @@ inline std::string getMaskingSpecializationString(const MaskingSpecialization& s
    default: return "Unrecognized specialization!";
    }
}
#endif

struct MaskDisabledPredicate
{
@@ -53,7 +55,7 @@ struct MaskOutUpperTrianglePredicate
template <typename MaskOutPredicate>
struct C0MatrixMask_impl
{
-   __host__ __device__ C0MatrixMask_impl(index_t NRaw)
+   __host__ __device__ constexpr C0MatrixMask_impl(index_t NRaw)
        : NRaw_(NRaw), predicate_(MaskOutPredicate{})
    {
    }
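For reference, my reading of MaskOutUpperTrianglePredicate (an assumption based on its name, since the predicate body is not shown in this hunk): element (m, n) of the first GEMM's output is dropped before the softmax whenever n > m, which is exactly a causal-attention mask.

// Assumed semantics of the upper-triangle mask:
constexpr bool is_masked_out(int m, int n) { return n > m; }
static_assert(is_masked_out(0, 1));  // future position: masked
static_assert(!is_masked_out(1, 1)); // diagonal: kept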
include/ck/tensor_operation/gpu/device/tensor_layout.hpp

@@ -430,6 +430,7 @@ struct G_NDHW : public BaseTensorLayout
} // namespace convolution

#ifndef __HIPCC_RTC__
template <typename Layout,
          typename std::enable_if<std::is_base_of<BaseTensorLayout, Layout>::value, bool>::type = false>
@@ -438,6 +439,7 @@ std::ostream& operator<<(std::ostream& os, const Layout&)
    os << Layout::name;
    return os;
}
#endif

} // namespace tensor_layout
} // namespace ck
include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp

@@ -340,8 +340,8 @@ struct Bilinear
    };

    template <>
-   __host__ __device__ constexpr void operator()<std::int8_t, std::int32_t, std::int8_t>(
-       std::int8_t& y, const std::int32_t& x0, const std::int8_t& x1) const
+   __host__ __device__ constexpr void
+   operator()<int8_t, int32_t, int8_t>(int8_t& y, const int32_t& x0, const int8_t& x1) const
    {
        y = type_convert<int8_t>(alpha_ * type_convert<float>(x0) +
                                 beta_ * type_convert<float>(x1));
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp

@@ -466,7 +466,7 @@ struct FastGelu
    template <typename Y, typename X>
    __device__ void operator()(Y& y, const X& x) const;

#ifndef __HIPCC_RTC__
    template <>
    __host__ void operator()<float, float>(float& y, const float& x) const
    {
@@ -477,7 +477,7 @@ struct FastGelu
        const float emu = exp(u);
        y = x / (1.f + emu);
    }
#endif

    // device code, use lower precision "__ocml_exp_f32" and "rcp"
    template <>
    __device__ void operator()<float, float>(float& y, const float& x) const
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp

@@ -7,8 +7,10 @@
#include "ck/utility/number.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"

#ifndef __HIPCC_RTC__
#include <limits>
#include <stdlib.h>
#endif

namespace ck {
@@ -979,7 +981,7 @@ struct BlockToCTileMap_3DGrid_KSplit
        const auto M0 = math::integer_divide_ceil(M, MPerBlock);
        const auto N0 = math::integer_divide_ceil(N, NPerBlock);
-       return std::make_tuple(N0, M0, k_split);
+       return ck::make_tuple(N0, M0, k_split);
    }

    template <typename TopIdx>
@@ -1103,7 +1105,7 @@ struct BlockToCTileMap_GemmStreamK
        uint32_t dp_for_sk_iters = k_iters_per_tile.get();
-       uint32_t best_sk_score   = std::numeric_limits<int>::max(); // we need to find the smallest sk iters
+       uint32_t best_sk_score   = ck::NumericLimits<int>::Max(); // we need to find the smallest sk iters
        for(uint32_t tentative_sk_blocks = min_sk_tiles; tentative_sk_blocks < max_sk_tiles;
            tentative_sk_blocks++)
        {
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp

@@ -475,9 +475,9 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
    template <typename DsLayout, GemmSpecialization GemmSpec>
    __host__ __device__ static auto
-   MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws,
-                            const std::array<index_t, NumDTensor>& NRaws,
-                            const std::array<index_t, NumDTensor>& DsStride)
+   MakeDsGridDescriptor_M_N(const Array<index_t, NumDTensor>& MRaws,
+                            const Array<index_t, NumDTensor>& NRaws,
+                            const Array<index_t, NumDTensor>& DsStride)
    {
        return generate_tuple(
            [&](auto i) {
@@ -941,7 +941,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
        const index_t K,
        const index_t StrideA,
        const index_t StrideB,
-       const std::array<index_t, NumDTensor> StrideDs,
+       const Array<index_t, NumDTensor> StrideDs,
        const index_t StrideE,
        const Block2ETileMap& block_2_etile_map)
    {
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp

@@ -3,8 +3,10 @@
#pragma once

#ifndef __HIPCC_RTC__
#include <iostream>
#include <ostream>
#endif

#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp"
@@ -53,12 +55,15 @@ constexpr auto GridwiseGemmPipeline_Selector()
    }
    else
    {
#ifndef __HIPCC_RTC__
        std::cerr << "GridwiseGemmPipeline configuration is not available" << std::endl;
#endif
    }
}

} // namespace ck

#ifndef __HIPCC_RTC__
inline std::ostream& operator<<(std::ostream& os, const ck::PipelineVersion& p)
{
    switch(p)
@@ -71,3 +76,4 @@ inline std::ostream& operator<<(std::ostream& os, const ck::PipelineVersion& p)
    }
    return os;
}
#endif
include/ck/utility/amd_buffer_addressing.hpp

@@ -1005,6 +1005,7 @@ llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc,
                                index_t offset,
                                index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds");

#ifndef __HIPCC_RTC__
template <typename T, index_t NumElemsPerThread>
__device__ void amd_direct_load_global_to_lds(const T* global_base_ptr,
                                              const index_t global_offset,
@@ -1042,5 +1043,6 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr,
        src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0);
#endif
}
#endif

} // namespace ck
include/ck/utility/amd_wave_read_first_lane.hpp

@@ -7,10 +7,12 @@
#include "ck/utility/functional2.hpp"
#include "ck/utility/math.hpp"

#ifndef __HIPCC_RTC__
#include <array>
#include <cstddef>
#include <cstdint>
#include <type_traits>
#endif

namespace ck {
namespace detail {
@@ -37,7 +39,7 @@ struct get_carrier<3>
{
    using value_type = uint32_t;

-   std::array<std::byte, 3> bytes;
+   Array<ck::byte, 3> bytes;

    static_assert(sizeof(bytes) <= sizeof(value_type));

    // replacement of host std::copy_n()
@@ -61,22 +63,22 @@ struct get_carrier<3>
    // method to trigger template substitution failure
    __device__ carrier(const carrier& other) noexcept
    {
-       copy_n(other.bytes.begin(), bytes.size(), bytes.begin());
+       copy_n(other.bytes.begin(), bytes.Size(), bytes.begin());
    }

    public:
    __device__ carrier& operator=(value_type value) noexcept
    {
-       copy_n(reinterpret_cast<const std::byte*>(&value), bytes.size(), bytes.begin());
+       copy_n(reinterpret_cast<const ck::byte*>(&value), bytes.Size(), bytes.begin());

        return *this;
    }

    __device__ operator value_type() const noexcept
    {
-       std::byte result[sizeof(value_type)];
+       ck::byte result[sizeof(value_type)];

-       copy_n(bytes.begin(), bytes.size(), result);
+       copy_n(bytes.begin(), bytes.Size(), result);

        return *reinterpret_cast<const value_type*>(result);
    }
@@ -109,8 +111,8 @@ __device__ inline int64_t amd_wave_read_first_lane(int64_t value)
{
    constexpr unsigned object_size        = sizeof(int64_t);
    constexpr unsigned second_part_offset = object_size / 2;

-   auto* const from_obj = reinterpret_cast<const std::byte*>(&value);
-   alignas(int64_t) std::byte to_obj[object_size];
+   auto* const from_obj = reinterpret_cast<const ck::byte*>(&value);
+   alignas(int64_t) ck::byte to_obj[object_size];

    using Sgpr = uint32_t;
@@ -124,15 +126,15 @@ __device__ inline int64_t amd_wave_read_first_lane(int64_t value)
template <typename Object,
-         typename = std::enable_if_t<std::is_class_v<Object> && std::is_trivially_copyable_v<Object>>>
+         typename = ck::enable_if_t<ck::is_class_v<Object> && ck::is_trivially_copyable_v<Object>>>
__device__ auto amd_wave_read_first_lane(const Object& obj)
{
    using Size = unsigned;

    constexpr Size SgprSize   = 4;
    constexpr Size ObjectSize = sizeof(Object);

-   auto* const from_obj = reinterpret_cast<const std::byte*>(&obj);
-   alignas(Object) std::byte to_obj[ObjectSize];
+   auto* const from_obj = reinterpret_cast<const ck::byte*>(&obj);
+   alignas(Object) ck::byte to_obj[ObjectSize];

    constexpr Size RemainedSize             = ObjectSize % SgprSize;
    constexpr Size CompleteSgprCopyBoundary = ObjectSize - RemainedSize;
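A usage sketch for context (assuming the usual semantics of readfirstlane: broadcast the first active lane's value so the result is wavefront-uniform and can live in scalar registers). The class-type overload touched above accepts any trivially copyable struct, e.g. a small coordinate pair; TileCoord and make_uniform are made-up names:

struct TileCoord
{
    int m, n;
};

__device__ TileCoord make_uniform(TileCoord per_lane)
{
    // Copies the object SGPR-sized chunk by chunk through readfirstlane.
    return ck::amd_wave_read_first_lane(per_lane);
}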
include/ck/utility/array.hpp

@@ -38,6 +38,8 @@ struct Array
    }

    __host__ __device__ constexpr const TData* begin() const { return &mData[0]; }
    __host__ __device__ constexpr const TData* end() const { return &mData[NSize]; }
    __host__ __device__ constexpr TData* begin() { return &mData[0]; }
    __host__ __device__ constexpr TData* end() { return &mData[NSize]; }
};

// empty Array
@@ -54,7 +56,7 @@ template <typename X, typename... Xs>
__host__ __device__ constexpr auto make_array(X&& x, Xs&&... xs)
{
    using data_type = remove_cvref_t<X>;
-   return Array<data_type, sizeof...(Xs) + 1>{std::forward<X>(x), std::forward<Xs>(xs)...};
+   return Array<data_type, sizeof...(Xs) + 1>{ck::forward<X>(x), ck::forward<Xs>(xs)...};
}

// make empty array
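The new const begin()/end() overloads make ck::Array iterable in range-for on host and device; a small sketch of my own (not from the diff):

__host__ __device__ inline int sum3()
{
    const auto a = ck::make_array(1, 2, 3); // ck::Array<int, 3>
    int s = 0;
    for(int v : a) // uses the const begin()/end() added above
        s += v;
    return s;      // 6
}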
include/ck/utility/container_helper.hpp

@@ -326,14 +326,14 @@ template <typename T, index_t NX, index_t NY>
__host__ __device__ constexpr auto container_concat(const Array<T, NX>& ax, const Array<T, NY>& ay)
{
    return unpack2(
-       [&](auto&&... zs) { return make_array(std::forward<decltype(zs)>(zs)...); }, ax, ay);
+       [&](auto&&... zs) { return make_array(ck::forward<decltype(zs)>(zs)...); }, ax, ay);
}

template <typename... X, typename... Y>
__host__ __device__ constexpr auto container_concat(const Tuple<X...>& tx, const Tuple<Y...>& ty)
{
    return unpack2(
-       [&](auto&&... zs) { return make_tuple(std::forward<decltype(zs)>(zs)...); }, tx, ty);
+       [&](auto&&... zs) { return make_tuple(ck::forward<decltype(zs)>(zs)...); }, tx, ty);
}

template <typename Container>
include/ck/utility/data_type.hpp

@@ -5,8 +5,25 @@
#include "ck/utility/statically_indexed_array.hpp"

#ifdef __HIPCC_RTC__
/// Definitions from <cstdint>, <cmath> conflict with
/// /opt/rocm/include/hip/amd_detail/amd_hip_vector_types.h.
using int8_t   = signed char;
using uint8_t  = unsigned char;
using int16_t  = signed short;
using uint16_t = unsigned short;
using float_t  = float;
#endif // __HIPCC_RTC__

namespace ck {

#ifdef __HIPCC_RTC__
using byte = unsigned char;
#else
using std::byte;
#endif

using bhalf_t = ushort;
using half_t  = _Float16;
using int4_t  = _BitInt(4);
@@ -1060,6 +1077,146 @@ using uint8x16_t = typename vector_type<uint8_t, 16>::type;
using uint8x32_t = typename vector_type<uint8_t, 32>::type;
using uint8x64_t = typename vector_type<uint8_t, 64>::type;

#ifdef __HIPCC_RTC__
template <typename T>
struct NumericLimits;

template <>
struct NumericLimits<int32_t>
{
    __host__ __device__ static constexpr int32_t Lowest() noexcept { return -2147483647 - 1; }
    __host__ __device__ static constexpr int32_t Min() noexcept { return -2147483647 - 1; }
    __host__ __device__ static constexpr int32_t Max() noexcept { return 2147483647; }
    __host__ __device__ static constexpr int32_t Infinity() noexcept { return 0; }
    __host__ __device__ static constexpr int32_t QuietNaN() { return 0; }
};

template <>
struct NumericLimits<int16_t>
{
    __host__ __device__ static constexpr int16_t Lowest() noexcept { return -32768; }
    __host__ __device__ static constexpr int16_t Min() noexcept { return -32768; }
    __host__ __device__ static constexpr int16_t Max() noexcept { return 32767; }
    __host__ __device__ static constexpr int16_t Infinity() noexcept { return 0; }
    __host__ __device__ static constexpr int16_t QuietNaN() { return 0; }
};

template <>
struct NumericLimits<int8_t>
{
    __host__ __device__ static constexpr int8_t Lowest() noexcept { return -128; }
    __host__ __device__ static constexpr int8_t Min() noexcept { return -128; }
    __host__ __device__ static constexpr int8_t Max() noexcept { return 127; }
    __host__ __device__ static constexpr int8_t Infinity() noexcept { return 0; }
    __host__ __device__ static constexpr int8_t QuietNaN() { return 0; }
};

template <>
struct NumericLimits<uint32_t>
{
    __host__ __device__ static constexpr uint32_t Lowest() noexcept { return 0; }
    __host__ __device__ static constexpr uint32_t Min() noexcept { return 0; }
    __host__ __device__ static constexpr uint32_t Max() noexcept { return 4294967295U; }
    __host__ __device__ static constexpr uint32_t Infinity() noexcept { return 0; }
    __host__ __device__ static constexpr uint32_t QuietNaN() { return 0; }
};

template <>
struct NumericLimits<uint16_t>
{
    __host__ __device__ static constexpr uint16_t Lowest() noexcept { return 0; }
    __host__ __device__ static constexpr uint16_t Min() noexcept { return 0; }
    __host__ __device__ static constexpr uint16_t Max() noexcept { return 65535U; }
    __host__ __device__ static constexpr uint16_t Infinity() noexcept { return 0; }
    __host__ __device__ static constexpr uint16_t QuietNaN() { return 0; }
};

template <>
struct NumericLimits<float>
{
    static constexpr unsigned int binary_min    = 0x00800000;
    static constexpr unsigned int binary_max    = 0x7F7FFFFF;
    static constexpr unsigned int binary_lowest = 0xFF7FFFFF;
    static constexpr unsigned int binary_qnan   = 0xFFC00001;
    static constexpr unsigned int binary_inf    = 0x7F800000;

    __host__ __device__ static constexpr float Min() { return bit_cast<float>(binary_min); }
    __host__ __device__ static constexpr float Max() { return bit_cast<float>(binary_max); }
    __host__ __device__ static constexpr float Lowest() { return bit_cast<float>(binary_lowest); }
    __host__ __device__ static constexpr float QuietNaN() { return bit_cast<float>(binary_qnan); }
    __host__ __device__ static constexpr float Infinity() { return bit_cast<float>(binary_inf); }
};

template <>
struct NumericLimits<half_t>
{
    static constexpr unsigned short binary_min    = 0x0400;
    static constexpr unsigned short binary_max    = 0x7BFF;
    static constexpr unsigned short binary_lowest = 0xFBFF;
    static constexpr unsigned short binary_qnan   = 0x7FFF;

    __host__ __device__ static constexpr half_t Min() { return bit_cast<half_t>(binary_min); }
    __host__ __device__ static constexpr half_t Max() { return bit_cast<half_t>(binary_max); }
    __host__ __device__ static constexpr half_t Lowest() { return bit_cast<half_t>(binary_lowest); }
    __host__ __device__ static constexpr half_t QuietNaN() { return bit_cast<half_t>(binary_qnan); }
};

#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template <>
struct NumericLimits<int4_t>
{
    __host__ __device__ static constexpr int4_t Min() { return int4_t(-8); }
    __host__ __device__ static constexpr int4_t Max() { return int4_t(7); }
    __host__ __device__ static constexpr int4_t Lowest() { return int4_t(-8); }
};
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4

template <>
struct NumericLimits<f8_t>
{
    static constexpr uint8_t binary_min    = 0x08; // 0b00001000
    static constexpr uint8_t binary_max    = 0x77; // 0b01110111
    static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111
    static constexpr uint8_t binary_qnan   = 0x80; // 0b10000000

    __host__ __device__ static constexpr f8_t Min() { return bit_cast<f8_t>(binary_min); }
    __host__ __device__ static constexpr f8_t Max() { return bit_cast<f8_t>(binary_max); }
    __host__ __device__ static constexpr f8_t Lowest() { return bit_cast<f8_t>(binary_lowest); }
    __host__ __device__ static constexpr f8_t QuietNaN() { return bit_cast<f8_t>(binary_qnan); }
};
#else
template <typename T>
struct NumericLimits
{
@@ -1151,6 +1308,7 @@ struct NumericLimits<bf8_t>
    __host__ __device__ static constexpr bf8_t QuietNaN() { return bf8_t(binary_qnan); }
};
#endif

template <typename T>
struct NumericUtils
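Under HIPRTC, ck::NumericLimits stands in for std::numeric_limits (note the capitalized Min/Max/Lowest spelling), as in the block_to_ctile_map.hpp change above. A short usage sketch of my own:

// Seed a running maximum with the smallest representable value of T.
template <typename T>
__device__ T running_max_seed()
{
    return ck::NumericLimits<T>::Lowest();
}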
include/ck/utility/enable_if.hpp

@@ -4,11 +4,26 @@
#pragma once

namespace ck {

#ifdef __HIPCC_RTC__
template <bool B, class T = void>
struct enable_if
{
};

template <class T>
struct enable_if<true, T>
{
    using type = T;
};

template <bool B, class T = void>
using enable_if_t = typename enable_if<B, T>::type;
#else
template <bool B, typename T = void>
using enable_if = std::enable_if<B, T>;

template <bool B, typename T = void>
using enable_if_t = typename std::enable_if<B, T>::type;
#endif

} // namespace ck
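With this, SFINAE constraints compile identically with hipcc and HIPRTC, without pulling in <type_traits>. A sketch of my own (assuming ck::is_same_v, which the IsSupported() checks above also use):

template <typename T, ck::enable_if_t<ck::is_same_v<T, float>, bool> = false>
__device__ T halve(T x)
{
    return x * 0.5f; // only instantiable for T = float
}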
include/ck/utility/env.hpp

@@ -183,3 +183,7 @@ void UpdateEnvVar(EnvVar, const std::string_view& val)
}
} // namespace ck

// environment variable to enable logging:
// export CK_LOGGING=ON or CK_LOGGING=1 or CK_LOGGING=ENABLED
CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
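A usage sketch for the new flag; EnvIsEnabled and CK_ENV are assumed here to be the accessors that CK_DECLARE_ENV_VAR_BOOL pairs with in env.hpp:

#include "ck/utility/env.hpp"

inline bool logging_enabled()
{
    // True when CK_LOGGING is set to ON, 1, or ENABLED in the environment.
    return ck::EnvIsEnabled(CK_ENV(CK_LOGGING));
}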