Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
6252d207
Commit
6252d207
authored
Sep 13, 2024
by
carlushuang
Browse files
Merge remote-tracking branch 'origin/develop' into ck_tile/fav3_fwd_sept
parents
eed60199
e07f1108
Changes
51
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1863 additions
and
142 deletions
+1863
-142
Dockerfile
Dockerfile
+2
-0
Jenkinsfile
Jenkinsfile
+2
-2
example/01_gemm/run_gemm_example.inc
example/01_gemm/run_gemm_example.inc
+8
-0
include/ck/tensor_operation/gpu/block/blockwise_gemm_smfmac_xdlops.hpp
...nsor_operation/gpu/block/blockwise_gemm_smfmac_xdlops.hpp
+453
-0
include/ck/tensor_operation/gpu/device/impl/device_avgpool2d_bwd_nhwc_nhwc.hpp
...ration/gpu/device/impl/device_avgpool2d_bwd_nhwc_nhwc.hpp
+523
-0
include/ck/tensor_operation/gpu/device/impl/device_pool2d_fwd_nhwc_nhwc.hpp
...operation/gpu/device/impl/device_pool2d_fwd_nhwc_nhwc.hpp
+349
-74
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
...or_operation/gpu/element/unary_element_wise_operation.hpp
+28
-1
include/ck/tensor_operation/gpu/warp/smfmac_xdlops_gemm.hpp
include/ck/tensor_operation/gpu/warp/smfmac_xdlops_gemm.hpp
+38
-14
include/ck/utility/amd_smfmac.hpp
include/ck/utility/amd_smfmac.hpp
+14
-12
include/ck/utility/reduction_operator.hpp
include/ck/utility/reduction_operator.hpp
+165
-10
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
.../fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
+15
-15
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
+1
-1
include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp
...gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp
+1
-1
include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
...lock_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
+6
-6
include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp
...gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp
+1
-1
include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp
...ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp
+4
-4
library/include/ck/library/tensor_operation_instance/gpu/avg_pool2d_bwd.hpp
.../library/tensor_operation_instance/gpu/avg_pool2d_bwd.hpp
+80
-0
library/include/ck/library/tensor_operation_instance/gpu/max_pool_bwd.hpp
...ck/library/tensor_operation_instance/gpu/max_pool_bwd.hpp
+12
-1
library/include/ck/library/tensor_operation_instance/gpu/pool2d_fwd.hpp
...e/ck/library/tensor_operation_instance/gpu/pool2d_fwd.hpp
+153
-0
library/src/tensor_operation_instance/gpu/avg_pool2d_bwd/CMakeLists.txt
...nsor_operation_instance/gpu/avg_pool2d_bwd/CMakeLists.txt
+8
-0
No files found.
Dockerfile
View file @
6252d207
...
...
@@ -130,6 +130,8 @@ ENV compiler_commit=$compiler_commit
RUN
sh
-c
"echo compiler version = '
$compiler_version
'"
RUN
sh
-c
"echo compiler commit = '
$compiler_commit
'"
ARG
DISABLE_CACHE=0
RUN if
(
[
"
$compiler_version
"
=
"amd-staging"
]
||
[
"
$compiler_version
"
=
"amd-mainline-open"
]
)
&&
[
"
$compiler_commit
"
=
""
]
;
then
\
git clone
-b
"
$compiler_version
"
https://github.com/ROCm/llvm-project.git
&&
\
cd
llvm-project
&&
mkdir
build
&&
cd
build
&&
\
...
...
Jenkinsfile
View file @
6252d207
...
...
@@ -94,7 +94,7 @@ def getDockerImage(Map conf=[:]){
env
.
DOCKER_BUILDKIT
=
1
def
prefixpath
=
conf
.
get
(
"prefixpath"
,
"/opt/rocm"
)
def
no_cache
=
conf
.
get
(
"no_cache"
,
false
)
def
dockerArgs
=
"--build-arg BUILDKIT_INLINE_CACHE=1 --build-arg PREFIX=${prefixpath} --build-arg CK_SCCACHE='${env.CK_SCCACHE}' --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' "
def
dockerArgs
=
"--build-arg BUILDKIT_INLINE_CACHE=1 --build-arg PREFIX=${prefixpath} --build-arg CK_SCCACHE='${env.CK_SCCACHE}' --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}'
--build-arg DISABLE_CACHE='git rev-parse ${params.COMPILER_VERSION}'
"
if
(
no_cache
)
{
dockerArgs
=
dockerArgs
+
" --no-cache "
...
...
@@ -124,7 +124,7 @@ def buildDocker(install_prefix){
checkout
scm
def
image_name
=
getDockerImageName
()
echo
"Building Docker for ${image_name}"
def
dockerArgs
=
"--build-arg BUILDKIT_INLINE_CACHE=1 --build-arg PREFIX=${install_prefix} --build-arg CK_SCCACHE='${env.CK_SCCACHE}' --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' "
def
dockerArgs
=
"--build-arg BUILDKIT_INLINE_CACHE=1 --build-arg PREFIX=${install_prefix} --build-arg CK_SCCACHE='${env.CK_SCCACHE}' --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}'
--build-arg DISABLE_CACHE='git rev-parse ${params.COMPILER_VERSION}'
"
echo
"Build Args: ${dockerArgs}"
try
{
...
...
example/01_gemm/run_gemm_example.inc
View file @
6252d207
...
...
@@ -305,6 +305,14 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
}
#endif
}
else
{
// Reached when the instance does not support this problem type / problem size combination.
std
::
cerr
<<
gemm
.
GetTypeString
()
<<
": the instance does not support the problem config."
<<
std
::
endl
;
return
true
;
}
std
::
size_t
flop
=
2_
uz
*
M
*
N
*
K
;
std
::
size_t
num_btype
=
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_smfmac_xdlops.hpp
0 → 100644
View file @
6252d207
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/device/impl/device_avgpool2d_bwd_nhwc_nhwc.hpp
0 → 100644
View file @
6252d207
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/device/impl/device_pool2d_fwd_nhwc_nhwc.hpp
View file @
6252d207
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
View file @
6252d207
...
...
@@ -355,12 +355,39 @@ struct UnaryDivide
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
int32_t
>::
value
,
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
"Data type is not supported by this operation!"
);
y
=
x
/
type_convert
<
T
>
(
divider_
);
};
// half_t specialization: widens x and the integer divider_ to float,
// divides in float, and stores the quotient in y converted back to half_t.
template <>
__host__ __device__ void operator()<half_t>(half_t& y, const half_t& x) const
{
    const float numerator   = type_convert<float>(x);
    const float denominator = type_convert<float>(divider_);

    y = type_convert<half_t>(numerator / denominator);
}
// bhalf_t specialization: widens x and the integer divider_ to float,
// divides in float, and stores the quotient in y converted back to bhalf_t.
template <>
__host__ __device__ void operator()<bhalf_t>(bhalf_t& y, const bhalf_t& x) const
{
    const float numerator   = type_convert<float>(x);
    const float denominator = type_convert<float>(divider_);

    y = type_convert<bhalf_t>(numerator / denominator);
}
// f8_t specialization: widens x and the integer divider_ to float,
// divides in float, and stores the quotient in y converted back to f8_t.
template <>
__host__ __device__ void operator()<f8_t>(f8_t& y, const f8_t& x) const
{
    const float numerator   = type_convert<float>(x);
    const float denominator = type_convert<float>(divider_);

    y = type_convert<f8_t>(numerator / denominator);
}
int32_t
divider_
=
1
;
};
...
...
include/ck/tensor_operation/gpu/warp/smfmac_xdlops_gemm.hpp
View file @
6252d207
...
...
@@ -35,10 +35,16 @@ struct smfmac<SmfmacInstr::smfmac_f32_16x16x32f16>
static
constexpr
index_t
k_per_blk
=
8
;
static
constexpr
bool
is_k_reduction
=
true
;
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
const
int32_t
&
idx
,
FloatC
&
reg_c
)
const
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
index_t
idx_part
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
const
index_t
&
idx
,
FloatC
&
reg_c
)
const
{
intrin_smfmac_f32_16x16x32f16
<
MPerXdlops
,
NPerXdlops
>::
Run
(
a
,
b
,
idx
,
reg_c
);
intrin_smfmac_f32_16x16x32f16
<
MPerXdlops
,
NPerXdlops
>::
Run
<
FloatC
,
idx_part
>
(
a
,
b
,
idx
,
reg_c
);
}
};
...
...
@@ -57,10 +63,16 @@ struct smfmac<SmfmacInstr::smfmac_f32_32x32x16f16>
static
constexpr
index_t
k_per_blk
=
16
;
static
constexpr
bool
is_k_reduction
=
true
;
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
const
int32_t
&
idx
,
FloatC
&
reg_c
)
const
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
index_t
idx_part
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
const
index_t
&
idx
,
FloatC
&
reg_c
)
const
{
intrin_smfmac_f32_32x32x16f16
<
MPerXdlops
,
NPerXdlops
>::
Run
(
a
,
b
,
idx
,
reg_c
);
intrin_smfmac_f32_32x32x16f16
<
MPerXdlops
,
NPerXdlops
>::
Run
<
FloatC
,
idx_part
>
(
a
,
b
,
idx
,
reg_c
);
}
};
...
...
@@ -79,10 +91,16 @@ struct smfmac<SmfmacInstr::smfmac_f32_16x16x32bf16>
static
constexpr
index_t
k_per_blk
=
8
;
static
constexpr
bool
is_k_reduction
=
true
;
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
const
int32_t
&
idx
,
FloatC
&
reg_c
)
const
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
index_t
idx_part
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
const
index_t
&
idx
,
FloatC
&
reg_c
)
const
{
intrin_smfmac_f32_16x16x32bf16
<
MPerXdlops
,
NPerXdlops
>::
Run
(
a
,
b
,
idx
,
reg_c
);
intrin_smfmac_f32_16x16x32bf16
<
MPerXdlops
,
NPerXdlops
>::
Run
<
FloatC
,
idx_part
>
(
a
,
b
,
idx
,
reg_c
);
}
};
...
...
@@ -101,10 +119,16 @@ struct smfmac<SmfmacInstr::smfmac_f32_32x32x16bf16>
static
constexpr
index_t
k_per_blk
=
16
;
static
constexpr
bool
is_k_reduction
=
true
;
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
const
int32_t
&
idx
,
FloatC
&
reg_c
)
const
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
index_t
idx_part
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
const
index_t
&
idx
,
FloatC
&
reg_c
)
const
{
intrin_smfmac_f32_32x32x16bf16
<
MPerXdlops
,
NPerXdlops
>::
Run
(
a
,
b
,
idx
,
reg_c
);
intrin_smfmac_f32_32x32x16bf16
<
MPerXdlops
,
NPerXdlops
>::
Run
<
FloatC
,
idx_part
>
(
a
,
b
,
idx
,
reg_c
);
}
};
...
...
@@ -305,8 +329,8 @@ struct SparseXdlopsGemm
"base base_type must be half or bfloat16!"
);
static_for
<
0
,
KPack
/
smfmac_instr
.
k_per_blk
,
1
>
{}([
&
](
auto
k
)
{
smfmac_instr
.
template
run
<
MPerXdlops
,
NPerXdlops
>(
p_a_wave
[
k
],
p_b_wave
[
k
],
idx
[
k
],
p_c_thread
);
smfmac_instr
.
template
run
<
MPerXdlops
,
NPerXdlops
,
k
%
4
>(
p_a_wave
[
k
],
p_b_wave
[
k
],
idx
[
k
/
4
],
p_c_thread
);
});
}
...
...
include/ck/utility/amd_smfmac.hpp
View file @
6252d207
...
...
@@ -9,16 +9,18 @@ namespace ck {
template
<
index_t
MPerWave
,
index_t
NPerWave
>
struct
intrin_smfmac_f32_16x16x32f16
;
// for every smfmac instruction if CBSZ[1:0]=0, ABID[1:0] selects one of four 8-bit sets of sparse
// indices from reg_idx
template
<
>
struct
intrin_smfmac_f32_16x16x32f16
<
16
,
16
>
{
template
<
class
FloatC
>
template
<
class
FloatC
,
index_t
abid
=
0
>
__device__
static
void
Run
(
const
half4_t
&
reg_a
,
const
half8_t
&
reg_b
,
const
in
t32
_t
&
reg_idx
,
FloatC
&
reg_c
)
Run
(
const
half4_t
&
reg_a
,
const
half8_t
&
reg_b
,
const
in
dex
_t
&
reg_idx
,
FloatC
&
reg_c
)
{
#if defined(__gfx94__)
reg_c
.
template
AsType
<
float4_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_smfmac_f32_16x16x32_f16
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float4_t
>()[
Number
<
0
>
{}],
reg_idx
,
0
,
0
);
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float4_t
>()[
Number
<
0
>
{}],
reg_idx
,
0
,
abid
);
#else
ignore
=
reg_a
;
ignore
=
reg_b
;
...
...
@@ -34,13 +36,13 @@ struct intrin_smfmac_f32_16x16x32bf16;
template
<
>
struct
intrin_smfmac_f32_16x16x32bf16
<
16
,
16
>
{
template
<
class
FloatC
>
template
<
class
FloatC
,
index_t
abid
=
0
>
__device__
static
void
Run
(
const
bhalf4_t
&
reg_a
,
const
bhalf8_t
&
reg_b
,
const
in
t32
_t
&
reg_idx
,
FloatC
&
reg_c
)
Run
(
const
bhalf4_t
&
reg_a
,
const
bhalf8_t
&
reg_b
,
const
in
dex
_t
&
reg_idx
,
FloatC
&
reg_c
)
{
#if defined(__gfx94__)
reg_c
.
template
AsType
<
float4_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_smfmac_f32_16x16x32_bf16
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float4_t
>()[
Number
<
0
>
{}],
reg_idx
,
0
,
0
);
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float4_t
>()[
Number
<
0
>
{}],
reg_idx
,
0
,
abid
);
#else
ignore
=
reg_a
;
ignore
=
reg_b
;
...
...
@@ -56,13 +58,13 @@ struct intrin_smfmac_f32_32x32x16f16;
template
<
>
struct
intrin_smfmac_f32_32x32x16f16
<
32
,
32
>
{
template
<
class
FloatC
>
template
<
class
FloatC
,
index_t
abid
=
0
>
__device__
static
void
Run
(
const
half4_t
&
reg_a
,
const
half8_t
&
reg_b
,
const
in
t32
_t
&
reg_idx
,
FloatC
&
reg_c
)
Run
(
const
half4_t
&
reg_a
,
const
half8_t
&
reg_b
,
const
in
dex
_t
&
reg_idx
,
FloatC
&
reg_c
)
{
#if defined(__gfx94__)
reg_c
.
template
AsType
<
float16_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_smfmac_f32_32x32x16_f16
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float16_t
>()[
Number
<
0
>
{}],
reg_idx
,
0
,
0
);
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float16_t
>()[
Number
<
0
>
{}],
reg_idx
,
0
,
abid
);
#else
ignore
=
reg_a
;
ignore
=
reg_b
;
...
...
@@ -78,13 +80,13 @@ struct intrin_smfmac_f32_32x32x16bf16;
template
<
>
struct
intrin_smfmac_f32_32x32x16bf16
<
32
,
32
>
{
template
<
class
FloatC
>
template
<
class
FloatC
,
index_t
abid
=
0
>
__device__
static
void
Run
(
const
bhalf4_t
&
reg_a
,
const
bhalf8_t
&
reg_b
,
const
in
t32
_t
&
reg_idx
,
FloatC
&
reg_c
)
Run
(
const
bhalf4_t
&
reg_a
,
const
bhalf8_t
&
reg_b
,
const
in
dex
_t
&
reg_idx
,
FloatC
&
reg_c
)
{
#if defined(__gfx94__)
reg_c
.
template
AsType
<
float16_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_smfmac_f32_32x32x16_bf16
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float16_t
>()[
Number
<
0
>
{}],
reg_idx
,
0
,
0
);
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float16_t
>()[
Number
<
0
>
{}],
reg_idx
,
0
,
abid
);
#else
ignore
=
reg_a
;
ignore
=
reg_b
;
...
...
include/ck/utility/reduction_operator.hpp
View file @
6252d207
...
...
@@ -52,12 +52,28 @@ struct Add
__host__
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
half
_t
>::
value
,
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8
_t
>::
value
,
"The data type is not supported by the Add accumulator!"
);
a
=
a
+
b
;
}
// f8_t overload of the Add accumulator: both operands are converted to float,
// summed, and the sum is converted back to f8_t and written into `a`.
__host__ __device__ inline constexpr void operator()(f8_t& a, f8_t b) const
{
    const float lhs = type_convert<float>(a);
    const float rhs = type_convert<float>(b);

    a = type_convert<f8_t>(lhs + rhs);
}
// half_t overload of the Add accumulator: both operands are converted to float,
// summed, and the sum is converted back to half_t and written into `a`.
__host__ __device__ inline constexpr void operator()(half_t& a, half_t b) const
{
    const float lhs = type_convert<float>(a);
    const float rhs = type_convert<float>(b);

    a = type_convert<half_t>(lhs + rhs);
}
__host__
__device__
inline
constexpr
void
operator
()(
bhalf_t
&
a
,
bhalf_t
b
)
const
{
float
a_
=
type_convert
<
float
>
(
a
);
...
...
@@ -112,12 +128,28 @@ struct Mul
__host__
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
half
_t
>::
value
,
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8
_t
>::
value
,
"The data type is not supported by the Mul accumulator!"
);
a
=
a
*
b
;
}
// f8_t overload of the Mul accumulator: both operands are converted to float,
// multiplied, and the product is converted back to f8_t and written into `a`.
__host__ __device__ inline constexpr void operator()(f8_t& a, f8_t b) const
{
    const float lhs = type_convert<float>(a);
    const float rhs = type_convert<float>(b);

    a = type_convert<f8_t>(lhs * rhs);
}
// half_t overload of the Mul accumulator: both operands are converted to float,
// multiplied, and the product is converted back to half_t and written into `a`.
__host__ __device__ inline constexpr void operator()(half_t& a, half_t b) const
{
    const float lhs = type_convert<float>(a);
    const float rhs = type_convert<float>(b);

    a = type_convert<half_t>(lhs * rhs);
}
__host__
__device__
inline
constexpr
void
operator
()(
bhalf_t
&
a
,
bhalf_t
b
)
const
{
float
a_
=
type_convert
<
float
>
(
a
);
...
...
@@ -137,6 +169,16 @@ struct Max
float
val
=
NumericLimits
<
float
>::
Lowest
();
return
type_convert
<
bhalf_t
>
(
val
);
}
if
constexpr
(
is_same_v
<
T
,
f8_t
>
)
{
float
val
=
NumericLimits
<
float
>::
Lowest
();
return
type_convert
<
f8_t
>
(
val
);
}
if
constexpr
(
is_same_v
<
T
,
half_t
>
)
{
float
val
=
NumericLimits
<
float
>::
Lowest
();
return
type_convert
<
half_t
>
(
val
);
}
else
{
return
NumericLimits
<
T
>::
Lowest
();
...
...
@@ -154,8 +196,7 @@ struct Max
__host__
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
"The data type is not supported by the Max accumulator!"
);
if
(
a
<
b
)
...
...
@@ -171,12 +212,29 @@ struct Max
a
=
b
;
}
// half_t overload of the Max accumulator: the comparison is done on the
// float conversions of both operands; `a` takes the value of `b` only when
// a < b, otherwise it is left as is.
__host__ __device__ inline constexpr void operator()(half_t& a, half_t b) const
{
    const float current   = type_convert<float>(a);
    const float candidate = type_convert<float>(b);

    if(!(current < candidate))
    {
        return;
    }

    a = b;
}
// f8_t overload of the Max accumulator: the comparison is done on the
// float conversions of both operands; `a` takes the value of `b` only when
// a < b, otherwise it is left as is.
__host__ __device__ inline constexpr void operator()(f8_t& a, f8_t b) const
{
    const float current   = type_convert<float>(a);
    const float candidate = type_convert<float>(b);

    if(!(current < candidate))
    {
        return;
    }

    a = b;
}
template
<
typename
T
>
__host__
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
,
bool
&
changed
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
"The data type is not supported by the Max accumulator!"
);
if
(
a
<
b
)
...
...
@@ -197,6 +255,30 @@ struct Max
changed
=
true
;
}
}
// half_t overload of the change-tracking Max accumulator. Operands are
// compared as floats; when a < b, `a` is overwritten with `b` and `changed`
// is set to true. `changed` is not written otherwise.
__host__ __device__ inline constexpr void operator()(half_t& a, half_t b, bool& changed) const
{
    const float current   = type_convert<float>(a);
    const float candidate = type_convert<float>(b);

    if(!(current < candidate))
    {
        return;
    }

    a       = b;
    changed = true;
}
// f8_t overload of the change-tracking Max accumulator. Operands are
// compared as floats; when a < b, `a` is overwritten with `b` and `changed`
// is set to true. `changed` is not written otherwise.
__host__ __device__ inline constexpr void operator()(f8_t& a, f8_t b, bool& changed) const
{
    const float current   = type_convert<float>(a);
    const float candidate = type_convert<float>(b);

    if(!(current < candidate))
    {
        return;
    }

    a       = b;
    changed = true;
}
};
struct
Min
...
...
@@ -209,6 +291,16 @@ struct Min
float
val
=
NumericLimits
<
float
>::
Max
();
return
type_convert
<
bhalf_t
>
(
val
);
}
else
if
constexpr
(
is_same_v
<
T
,
half_t
>
)
{
float
val
=
NumericLimits
<
float
>::
Max
();
return
type_convert
<
half_t
>
(
val
);
}
else
if
constexpr
(
is_same_v
<
T
,
f8_t
>
)
{
float
val
=
NumericLimits
<
float
>::
Max
();
return
type_convert
<
f8_t
>
(
val
);
}
else
{
return
NumericLimits
<
T
>::
Max
();
...
...
@@ -227,8 +319,7 @@ struct Min
__host__
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
"The data type is not supported by the Min accumulator!"
);
if
(
a
>
b
)
...
...
@@ -244,6 +335,24 @@ struct Min
a
=
b
;
}
// half_t overload of the Min accumulator: the comparison is done on the
// float conversions of both operands; `a` takes the value of `b` only when
// a > b, otherwise it is left as is.
__host__ __device__ inline constexpr void operator()(half_t& a, half_t b) const
{
    const float current   = type_convert<float>(a);
    const float candidate = type_convert<float>(b);

    if(!(current > candidate))
    {
        return;
    }

    a = b;
}
// f8_t overload of the Min accumulator: the comparison is done on the
// float conversions of both operands; `a` takes the value of `b` only when
// a > b, otherwise it is left as is.
__host__ __device__ inline constexpr void operator()(f8_t& a, f8_t b) const
{
    const float current   = type_convert<float>(a);
    const float candidate = type_convert<float>(b);

    if(!(current > candidate))
    {
        return;
    }

    a = b;
}
template
<
typename
T
>
__host__
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
,
bool
&
changed
)
const
{
...
...
@@ -270,6 +379,30 @@ struct Min
changed
=
true
;
}
}
// half_t overload of the change-tracking Min accumulator. Operands are
// compared as floats; when a > b, `a` is overwritten with `b` and `changed`
// is set to true. `changed` is not written otherwise.
__host__ __device__ inline constexpr void operator()(half_t& a, half_t b, bool& changed) const
{
    const float current   = type_convert<float>(a);
    const float candidate = type_convert<float>(b);

    if(!(current > candidate))
    {
        return;
    }

    a       = b;
    changed = true;
}
// f8_t overload of the change-tracking Min accumulator. Operands are
// compared as floats; when a > b, `a` is overwritten with `b` and `changed`
// is set to true. `changed` is not written otherwise.
__host__ __device__ inline constexpr void operator()(f8_t& a, f8_t b, bool& changed) const
{
    const float current   = type_convert<float>(a);
    const float candidate = type_convert<float>(b);

    if(!(current > candidate))
    {
        return;
    }

    a       = b;
    changed = true;
}
};
struct
AMax
...
...
@@ -299,6 +432,15 @@ struct AMax
a
=
b
;
}
// f8_t overload of the AMax accumulator: operands are compared through their
// float conversions and `a` is replaced by `b` when a < b. No absolute value
// is taken in this overload.
__host__ __device__ inline constexpr void operator()(f8_t& a, f8_t b) const
{
    const float current   = type_convert<float>(a);
    const float candidate = type_convert<float>(b);

    if(!(current < candidate))
    {
        return;
    }

    a = b;
}
template
<
typename
T
>
__host__
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
,
bool
&
changed
)
const
{
...
...
@@ -313,6 +455,18 @@ struct AMax
changed
=
true
;
}
}
// f8_t overload of the change-tracking AMax accumulator. Operands are
// compared as floats; when a < b, `a` is overwritten with `b` and `changed`
// is set to true. `changed` is not written otherwise. No absolute value is
// taken in this overload.
__host__ __device__ inline constexpr void operator()(f8_t& a, f8_t b, bool& changed) const
{
    const float current   = type_convert<float>(a);
    const float candidate = type_convert<float>(b);

    if(!(current < candidate))
    {
        return;
    }

    a       = b;
    changed = true;
}
};
template
<
typename
T
>
...
...
@@ -352,7 +506,8 @@ struct InMemoryDataOperationSupportedOnDataType<InMemoryDataOperationEnum::Set,
static
constexpr
bool
value
=
is_same
<
DataType
,
float
>::
value
||
is_same
<
DataType
,
double
>::
value
||
is_same
<
DataType
,
half_t
>::
value
||
is_same
<
DataType
,
bhalf_t
>::
value
||
is_same
<
DataType
,
int8_t
>::
value
||
is_same
<
DataType
,
int32_t
>::
value
;
is_same
<
DataType
,
int8_t
>::
value
||
is_same
<
DataType
,
int32_t
>::
value
||
is_same
<
DataType
,
f8_t
>::
value
;
};
template
<
typename
DataType
>
...
...
@@ -361,7 +516,7 @@ struct InMemoryDataOperationSupportedOnDataType<InMemoryDataOperationEnum::Add,
static
constexpr
bool
value
=
is_same
<
DataType
,
float
>::
value
||
is_same
<
DataType
,
double
>::
value
||
is_same
<
DataType
,
half_t
>::
value
||
is_same
<
DataType
,
int8_t
>::
value
||
is_same
<
DataType
,
int32_t
>::
value
;
is_same
<
DataType
,
int32_t
>::
value
||
is_same
<
DataType
,
f8_t
>::
value
;
};
}
// namespace reduce
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
View file @
6252d207
...
...
@@ -29,9 +29,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename
Problem
::
QDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
AccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
BlockTile
::
kM0
,
Problem
::
BlockFmhaShape
::
BlockTile
::
kN0
,
Problem
::
BlockFmhaShape
::
BlockTile
::
kK0
>
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kK0
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
>>
;
...
...
@@ -62,9 +62,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename
Problem
::
GemmDataType
,
typename
Problem
::
OGradDataType
,
typename
Problem
::
AccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
BlockTile
::
kN0
,
Problem
::
BlockFmhaShape
::
BlockTile
::
kVHeaddim
,
Problem
::
BlockFmhaShape
::
BlockTile
::
kK1
>
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kVHeaddim
,
Problem
::
BlockFmhaShape
::
kK1
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm1BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm1WarpTile
>>
;
...
...
@@ -94,9 +94,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename
Problem
::
OGradDataType
,
typename
Problem
::
VDataType
,
typename
Problem
::
AccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
BlockTile
::
kM0
,
Problem
::
BlockFmhaShape
::
BlockTile
::
kN0
,
Problem
::
BlockFmhaShape
::
BlockTile
::
kK2
>
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kK2
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm2BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm2WarpTile
>>
;
...
...
@@ -127,9 +127,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename
Problem
::
GemmDataType
,
typename
Problem
::
QDataType
,
typename
Problem
::
AccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
BlockTile
::
kN0
,
Problem
::
BlockFmhaShape
::
BlockTile
::
kQKHeaddim
,
Problem
::
BlockFmhaShape
::
BlockTile
::
kK3
>
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kQKHeaddim
,
Problem
::
BlockFmhaShape
::
kK3
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm3BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm3WarpTile
>>
;
...
...
@@ -159,9 +159,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename
Problem
::
GemmDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
AccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
BlockTile
::
kM0
,
Problem
::
BlockFmhaShape
::
BlockTile
::
kQKHeaddim
,
Problem
::
BlockFmhaShape
::
BlockTile
::
kK4
>
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kQKHeaddim
,
Problem
::
BlockFmhaShape
::
kK4
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm4BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm4WarpTile
>>
;
...
...
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
View file @
6252d207
...
...
@@ -25,7 +25,7 @@ struct GemmKernel
using
LayoutA
=
remove_cvref_t
<
LayoutA_
>
;
using
LayoutB
=
remove_cvref_t
<
LayoutB_
>
;
using
LayoutC
=
remove_cvref_t
<
LayoutC_
>
;
static
constexpr
index_t
KernelBlockSize
=
GemmPipeline
::
Kernel
BlockSize
;
static
constexpr
index_t
KernelBlockSize
=
GemmPipeline
::
k
BlockSize
;
using
ADataType
=
remove_cvref_t
<
typename
GemmPipeline
::
ADataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
GemmPipeline
::
BDataType
>
;
...
...
include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp
View file @
6252d207
...
...
@@ -19,7 +19,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1
using
CDataType
=
remove_cvref_t
<
typename
Problem
::
CDataType
>
;
using
BlockGemmShape
=
remove_cvref_t
<
typename
Problem
::
BlockGemmShape
>
;
static
constexpr
index_t
Kernel
BlockSize
=
Problem
::
Kernel
BlockSize
;
static
constexpr
index_t
k
BlockSize
=
Problem
::
k
BlockSize
;
static
constexpr
index_t
kMPerBlock
=
BlockGemmShape
::
kM
;
static
constexpr
index_t
kNPerBlock
=
BlockGemmShape
::
kN
;
...
...
include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
View file @
6252d207
...
...
@@ -195,7 +195,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
{
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
constexpr
index_t
Kernel
BlockSize
=
Problem
::
Kernel
BlockSize
;
constexpr
index_t
k
BlockSize
=
Problem
::
k
BlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockGemmShape
::
kM
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
...
...
@@ -204,7 +204,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
constexpr
index_t
M2
=
get_warp_size
()
/
K0
;
#if 1 // coalesce reading for each blocks
constexpr
index_t
M1
=
Kernel
BlockSize
/
get_warp_size
();
constexpr
index_t
M1
=
k
BlockSize
/
get_warp_size
();
static_assert
(
M2
!=
0
,
"M2 is zero, which will lead to a division by zero error."
);
static_assert
(
M1
!=
0
,
"M1 is zero, which will lead to a division by zero error."
);
constexpr
index_t
M0
=
kMPerBlock
/
(
M2
*
M1
);
...
...
@@ -217,7 +217,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
#else // coalesce reading for each warps
constexpr
index_t
M0
=
Kernel
BlockSize
/
get_warp_size
();
constexpr
index_t
M0
=
k
BlockSize
/
get_warp_size
();
constexpr
index_t
M1
=
kMPerBlock
/
(
M2
*
M0
);
return
make_static_tile_distribution
(
...
...
@@ -235,7 +235,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
{
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
constexpr
index_t
Kernel
BlockSize
=
Problem
::
Kernel
BlockSize
;
constexpr
index_t
k
BlockSize
=
Problem
::
k
BlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
...
...
@@ -244,7 +244,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
constexpr
index_t
N2
=
get_warp_size
()
/
K0
;
#if 1 // coalesce reading for each blocks
constexpr
index_t
N1
=
Kernel
BlockSize
/
get_warp_size
();
constexpr
index_t
N1
=
k
BlockSize
/
get_warp_size
();
static_assert
(
N2
!=
0
,
"M2 is zero, which will lead to a division by zero error."
);
static_assert
(
N1
!=
0
,
"M1 is zero, which will lead to a division by zero error."
);
constexpr
index_t
N0
=
kNPerBlock
/
(
N2
*
N1
);
...
...
@@ -257,7 +257,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
#else // coalesce reading for each warps
constexpr
index_t
N0
=
Kernel
BlockSize
/
get_warp_size
();
constexpr
index_t
N0
=
k
BlockSize
/
get_warp_size
();
constexpr
index_t
N1
=
kNPerBlock
/
(
N2
*
N0
);
return
make_static_tile_distribution
(
...
...
include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp
View file @
6252d207
...
...
@@ -19,7 +19,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV2
using
CDataType
=
remove_cvref_t
<
typename
Problem
::
CDataType
>
;
using
BlockGemmShape
=
remove_cvref_t
<
typename
Problem
::
BlockGemmShape
>
;
static
constexpr
index_t
Kernel
BlockSize
=
Problem
::
Kernel
BlockSize
;
static
constexpr
index_t
k
BlockSize
=
Problem
::
k
BlockSize
;
static
constexpr
index_t
kMPerBlock
=
BlockGemmShape
::
kM
;
static
constexpr
index_t
kNPerBlock
=
BlockGemmShape
::
kN
;
...
...
include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp
View file @
6252d207
...
...
@@ -23,7 +23,7 @@ struct BlockGemmPipelineProblem
using
CDataType
=
remove_cvref_t
<
CDataType_
>
;
using
BlockGemmShape
=
remove_cvref_t
<
BlockGemmShape_
>
;
static
constexpr
index_t
Kernel
BlockSize
=
BlockGemmShape
::
NumWarps
*
get_warp_size
();
static
constexpr
index_t
k
BlockSize
=
BlockGemmShape
::
NumWarps
*
get_warp_size
();
static
constexpr
bool
kPadA
=
kPadA_
;
static
constexpr
bool
kPadB
=
kPadB_
;
static
constexpr
bool
kPadC
=
kPadC_
;
...
...
library/include/ck/library/tensor_operation_instance/gpu/avg_pool2d_bwd.hpp
0 → 100644
View file @
6252d207
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/device/device_avgpool_bwd.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
#ifdef CK_ENABLE_BF16
void
add_device_avgpool_2D_bwd_nhwc_bf16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceAvgPoolBwd
<
2
,
BF16
,
BF16
,
NHWC
,
NHWC
>>>&
);
#endif
#ifdef CK_ENABLE_FP16
void
add_device_avgpool_2D_bwd_nhwc_f16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceAvgPoolBwd
<
2
,
F16
,
F16
,
NHWC
,
NHWC
>>>&
);
#endif
#ifdef CK_ENABLE_FP8
void
add_device_avgpool_2D_bwd_nhwc_f8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceAvgPoolBwd
<
2
,
F8
,
F8
,
NHWC
,
NHWC
>>>&
);
#endif
#ifdef CK_ENABLE_FP32
void
add_device_avgpool_2D_bwd_nhwc_f32_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceAvgPoolBwd
<
2
,
F32
,
F32
,
NHWC
,
NHWC
>>>&
);
#endif
#ifdef CK_ENABLE_INT8
void
add_device_avgpool_2D_bwd_nhwc_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceAvgPoolBwd
<
2
,
I8
,
I8
,
NHWC
,
NHWC
>>>&
);
#endif
/// Instance factory specialization for DeviceAvgPoolBwd<2, ...> (2D average-pool backward).
///
/// GetInstances() returns the registered device-op instances matching the requested types.
/// Instances are only added when both layouts are NHWC and DOutDataType / DInDataType are the
/// same one of F16, BF16, F32, F8 or I8, and only if the corresponding CK_ENABLE_* macro is
/// defined. Every other combination yields an empty vector.
template <typename DOutDataType, typename DInDataType, typename InLayout, typename OutLayout>
struct DeviceOperationInstanceFactory<
    ck::tensor_operation::device::
        DeviceAvgPoolBwd<2, DOutDataType, DInDataType, InLayout, OutLayout>>
{
    using DeviceOp = DeviceAvgPoolBwd<2, DOutDataType, DInDataType, InLayout, OutLayout>;

    /// Collects all matching instances.
    /// @return vector of owning pointers to DeviceOp; empty when nothing matches.
    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

        if constexpr(is_same_v<InLayout, NHWC> && is_same_v<OutLayout, NHWC>)
        {
            // Every data type uses its own stand-alone `if constexpr` instead of an
            // `else if` chain. The conditions are mutually exclusive, so at most one branch
            // is taken, and no branch depends on a preceding `#ifdef` region being compiled
            // in: a chained `else` would be left dangling (a compile error) as soon as
            // CK_ENABLE_FP16 is undefined while a later type is enabled.
#ifdef CK_ENABLE_FP16
            if constexpr(is_same_v<DOutDataType, F16> && is_same_v<DInDataType, F16>)
            {
                add_device_avgpool_2D_bwd_nhwc_f16_instances(op_ptrs);
            }
#endif
#ifdef CK_ENABLE_BF16
            if constexpr(is_same_v<DOutDataType, BF16> && is_same_v<DInDataType, BF16>)
            {
                add_device_avgpool_2D_bwd_nhwc_bf16_instances(op_ptrs);
            }
#endif
#ifdef CK_ENABLE_FP32
            if constexpr(is_same_v<DOutDataType, F32> && is_same_v<DInDataType, F32>)
            {
                add_device_avgpool_2D_bwd_nhwc_f32_instances(op_ptrs);
            }
#endif
#ifdef CK_ENABLE_FP8
            if constexpr(is_same_v<DOutDataType, F8> && is_same_v<DInDataType, F8>)
            {
                add_device_avgpool_2D_bwd_nhwc_f8_instances(op_ptrs);
            }
#endif
#ifdef CK_ENABLE_INT8
            if constexpr(is_same_v<DOutDataType, I8> && is_same_v<DInDataType, I8>)
            {
                add_device_avgpool_2D_bwd_nhwc_int8_instances(op_ptrs);
            }
#endif
        }

        return op_ptrs;
    }
};
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/include/ck/library/tensor_operation_instance/gpu/max_pool_bwd.hpp
View file @
6252d207
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -23,6 +23,11 @@ void add_device_maxpool_bwd_bf16_instances(
// Registrar for the DeviceMaxPoolBwd<F32, I32, F32> instances; appends to the given list.
void add_device_maxpool_bwd_f32_instances(
    std::vector<std::unique_ptr<DeviceMaxPoolBwd<F32, I32, F32>>>&);
#endif
#ifdef CK_ENABLE_INT8
// Registrar for the DeviceMaxPoolBwd<I8, I32, I8> instances; only declared when
// CK_ENABLE_INT8 is defined.
void add_device_maxpool_bwd_int8_instances(
    std::vector<std::unique_ptr<DeviceMaxPoolBwd<I8, I32, I8>>>&);
#endif
template
<
typename
DOutDataType
,
typename
IndexDataType
,
typename
DInDataType
>
struct
DeviceOperationInstanceFactory
<
ck
::
tensor_operation
::
device
::
DeviceMaxPoolBwd
<
DOutDataType
,
IndexDataType
,
DInDataType
>>
...
...
@@ -32,6 +37,7 @@ struct DeviceOperationInstanceFactory<
static
auto
GetInstances
()
{
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
#ifdef CK_ENABLE_FP16
if
constexpr
(
is_same_v
<
DOutDataType
,
F16
>
&&
is_same_v
<
DInDataType
,
F16
>
&&
is_same_v
<
IndexDataType
,
I32
>
)
...
...
@@ -47,6 +53,11 @@ struct DeviceOperationInstanceFactory<
is_same_v
<
IndexDataType
,
I32
>
)
add_device_maxpool_bwd_f32_instances
(
op_ptrs
);
#endif
#ifdef CK_ENABLE_INT8
else
if
constexpr
(
is_same_v
<
DOutDataType
,
I8
>
&&
is_same_v
<
DInDataType
,
I8
>
&&
is_same_v
<
IndexDataType
,
I32
>
)
add_device_maxpool_bwd_int8_instances
(
op_ptrs
);
#endif
return
op_ptrs
;
}
...
...
library/include/ck/library/tensor_operation_instance/gpu/pool2d_fwd.hpp
0 → 100644
View file @
6252d207
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_pool_fwd.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
// Rank constants used as the first two DevicePoolFwd template arguments in this header.
static constexpr auto InOutRank  = 4;
static constexpr auto WindowRank = 2;
// Short names for the two ck::ReduceTensorOp values referenced by this header.
static constexpr auto MaxOp = ck::ReduceTensorOp::MAX;
static constexpr auto AvgOp = ck::ReduceTensorOp::AVG;
// Forward declarations of the 2D pooling forward instance registrars (NHWC in / NHWC out,
// I32 index type). For every element type there are three declarations:
//   * <MaxOp, false> and <AvgOp, false> overloads of the plain registrar, and
//   * a separate "_index_" registrar for <MaxOp, true>.
// Each group is only declared when its CK_ENABLE_* macro is defined.

#ifdef CK_ENABLE_FP16
// FP16
void add_device_pool2d_fwd_nhwc_f16_instances(
    std::vector<std::unique_ptr<
        DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, NHWC, NHWC, MaxOp, false>>>&);

void add_device_pool2d_fwd_nhwc_f16_instances(
    std::vector<std::unique_ptr<
        DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, NHWC, NHWC, AvgOp, false>>>&);

// FP16 - return index
void add_device_pool2d_fwd_nhwc_index_f16_instances(
    std::vector<std::unique_ptr<
        DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, NHWC, NHWC, MaxOp, true>>>&);
#endif

#ifdef CK_ENABLE_BF16
// BF16
void add_device_pool2d_fwd_nhwc_bf16_instances(
    std::vector<std::unique_ptr<
        DevicePoolFwd<InOutRank, WindowRank, BF16, BF16, I32, NHWC, NHWC, MaxOp, false>>>&);

void add_device_pool2d_fwd_nhwc_bf16_instances(
    std::vector<std::unique_ptr<
        DevicePoolFwd<InOutRank, WindowRank, BF16, BF16, I32, NHWC, NHWC, AvgOp, false>>>&);

// BF16 - return index
void add_device_pool2d_fwd_nhwc_index_bf16_instances(
    std::vector<std::unique_ptr<
        DevicePoolFwd<InOutRank, WindowRank, BF16, BF16, I32, NHWC, NHWC, MaxOp, true>>>&);
#endif

#ifdef CK_ENABLE_FP32
// FP32
void add_device_pool2d_fwd_nhwc_f32_instances(
    std::vector<std::unique_ptr<
        DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, NHWC, NHWC, MaxOp, false>>>&);

void add_device_pool2d_fwd_nhwc_f32_instances(
    std::vector<std::unique_ptr<
        DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, NHWC, NHWC, AvgOp, false>>>&);

// FP32 - return index
void add_device_pool2d_fwd_nhwc_index_f32_instances(
    std::vector<std::unique_ptr<
        DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, NHWC, NHWC, MaxOp, true>>>&);
#endif
/// Instance factory specialization for DevicePoolFwd<InOutRank, WindowRank, ...>
/// (2D pooling forward).
///
/// GetInstances() returns the registered device-op instances matching the requested types.
/// Instances are only added when both layouts are NHWC, IndexDataType is I32 and
/// InDataType / OutDataType are the same one of F16, BF16 or F32, and only if the
/// corresponding CK_ENABLE_* macro is defined. Within a matching data type, the "_index_"
/// registrar is used when OutputIndex is true and ReduceOpId is MaxOp; otherwise the plain
/// registrar overload for DeviceOp is called.
template <typename InDataType,
          typename OutDataType,
          typename IndexDataType,
          typename InLayout,
          typename OutLayout,
          ck::ReduceTensorOp ReduceOpId,
          bool OutputIndex>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DevicePoolFwd<InOutRank,
                                                                                  WindowRank,
                                                                                  InDataType,
                                                                                  OutDataType,
                                                                                  IndexDataType,
                                                                                  InLayout,
                                                                                  OutLayout,
                                                                                  ReduceOpId,
                                                                                  OutputIndex>>
{
    using DeviceOp = DevicePoolFwd<InOutRank,
                                   WindowRank,
                                   InDataType,
                                   OutDataType,
                                   IndexDataType,
                                   InLayout,
                                   OutLayout,
                                   ReduceOpId,
                                   OutputIndex>;

    /// Collects all matching instances.
    /// @return vector of owning pointers to DeviceOp; empty when nothing matches.
    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

        if constexpr(is_same_v<InLayout, NHWC> && is_same_v<OutLayout, NHWC>)
        {
            // Every data type uses its own stand-alone `if constexpr` instead of an
            // `else if` chain. The conditions are mutually exclusive, so at most one branch
            // is taken, and no branch depends on a preceding `#ifdef` region being compiled
            // in: a chained `else` would be left dangling (a compile error) as soon as
            // CK_ENABLE_FP16 is undefined while a later type is enabled.
#ifdef CK_ENABLE_FP16
            if constexpr(is_same_v<InDataType, F16> && is_same_v<OutDataType, F16> &&
                         is_same_v<IndexDataType, I32>)
            {
                if constexpr(OutputIndex && ReduceOpId == MaxOp)
                {
                    add_device_pool2d_fwd_nhwc_index_f16_instances(op_ptrs);
                }
                else
                {
                    add_device_pool2d_fwd_nhwc_f16_instances(op_ptrs);
                }
            }
#endif
#ifdef CK_ENABLE_BF16
            if constexpr(is_same_v<InDataType, BF16> && is_same_v<OutDataType, BF16> &&
                         is_same_v<IndexDataType, I32>)
            {
                if constexpr(OutputIndex && ReduceOpId == MaxOp)
                {
                    add_device_pool2d_fwd_nhwc_index_bf16_instances(op_ptrs);
                }
                else
                {
                    add_device_pool2d_fwd_nhwc_bf16_instances(op_ptrs);
                }
            }
#endif
#ifdef CK_ENABLE_FP32
            if constexpr(is_same_v<InDataType, F32> && is_same_v<OutDataType, F32> &&
                         is_same_v<IndexDataType, I32>)
            {
                if constexpr(OutputIndex && ReduceOpId == MaxOp)
                {
                    add_device_pool2d_fwd_nhwc_index_f32_instances(op_ptrs);
                }
                else
                {
                    add_device_pool2d_fwd_nhwc_f32_instances(op_ptrs);
                }
            }
#endif
        }

        return op_ptrs;
    }
};
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/avg_pool2d_bwd/CMakeLists.txt
0 → 100644
View file @
6252d207
# Source files that make up the avg_pool2d backward instance library.
set(DEVICE_AVGPOOL_2D_BWD_INSTANCES
    device_avg_pool2d_bwd_nhwc_bf16_instance.cpp
    device_avg_pool2d_bwd_nhwc_f16_instance.cpp
    device_avg_pool2d_bwd_nhwc_f32_instance.cpp
    device_avg_pool2d_bwd_nhwc_f8_instance.cpp
    device_avg_pool2d_bwd_nhwc_int8_instance.cpp)

add_instance_library(device_avg_pool2d_bwd_instance ${DEVICE_AVGPOOL_2D_BWD_INSTANCES})
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment