gaoqiong / composable_kernel_ROCM · Commits

Commit 0475a327, authored Nov 04, 2024 by dummycoderfe

Merge branch 'ck_tile/layernorm2d_fwd_optimize' into ck_tile/ln_add_cache_clear

Parents: c9b961ab, 27ff3dec

Showing 20 changed files with 1560 additions and 166 deletions (+1560 −166)
include/ck_tile/ops/common/generic_2d_block_shape.hpp (+3, −4)
include/ck_tile/ops/elementwise.hpp (+8, −0)
include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp (+1163, −0)
include/ck_tile/ops/epilogue.hpp (+2, −0)
include/ck_tile/ops/epilogue/default_2d_epilogue.hpp (+17, −11)
include/ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp (+188, −0)
include/ck_tile/ops/fmha.hpp (+1, −0)
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp (+4, −4)
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp (+4, −4)
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp (+12, −11)
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp (+12, −11)
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp (+14, −13)
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp (+11, −11)
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp (+12, −11)
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp (+9, −6)
include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp (+19, −3)
include/ck_tile/ops/gemm.hpp (+3, −0)
include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp (+1, −1)
include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp (+15, −15)
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp (+62, −61)
include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_shape.hpp → include/ck_tile/ops/common/generic_2d_block_shape.hpp (renamed)
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

 #include "ck_tile/core.hpp"

 namespace ck_tile {

 /*
 // clang-format off
...
@@ -42,7 +41,7 @@ template <typename BlockTile_, // block size, seq<M, N>
           typename Vector_, // contiguous pixels(vector size) along seq<M, N>
           index_t BlockSize_ = warpSize * reduce_on_sequence(WarpPerBlock_{}, multiplies{}, number<1>{})>
-struct Layernorm2dShape
+struct Generic2dBlockShape
 {
     // block size
     static constexpr index_t Block_M = BlockTile_::at(number<0>{});
...
...
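For orientation, a worked example of the default block size (the warpSize of 64 is the usual gfx9 value and, like the 2x2 warp grid, is an assumption for illustration only):

    BlockSize_ = warpSize * reduce_on_sequence(sequence<2, 2>{}, multiplies{}, number<1>{})
               = 64 * (2 * 2)
               = 256 threads per block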
include/ck_tile/ops/elementwise.hpp  0 → 100644 (new file)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp  0 → 100644 (new file)

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include <type_traits>

namespace ck_tile {
namespace element_wise {
#if 0
struct PassThroughPack2
{
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const;
CK_TILE_HOST_DEVICE constexpr void operator()(ck_tile::half2_t& y, const ck_tile::f8x2_t& x) const
{
auto t = type_convert<float2_t>(x);
y = type_convert<half2_t>(t);
}
constexpr const static bool is_pack2_invocable = true;
};
#endif
struct PassThrough
{
    template <typename Y, typename X>
    CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const;

    template <> CK_TILE_HOST_DEVICE void operator()<double, double>(double& y, const double& x) const { y = x; }
    template <> CK_TILE_HOST_DEVICE void operator()<float, double>(float& y, const double& x) const { y = type_convert<float>(x); }
    template <> CK_TILE_HOST_DEVICE void operator()<double, float>(double& y, const float& x) const { y = type_convert<double>(x); }
    template <> CK_TILE_HOST_DEVICE void operator()<float, float>(float& y, const float& x) const { y = x; }
    template <> CK_TILE_HOST_DEVICE void operator()<ck_tile::fp16_t, ck_tile::fp16_t>(ck_tile::fp16_t& y, const ck_tile::fp16_t& x) const { y = x; }
    template <> CK_TILE_HOST_DEVICE void operator()<ck_tile::fp16_t, float>(ck_tile::fp16_t& y, const float& x) const { y = type_convert<ck_tile::fp16_t>(x); }
    template <> CK_TILE_HOST_DEVICE void operator()<ck_tile::bf16_t, ck_tile::bf16_t>(ck_tile::bf16_t& y, const ck_tile::bf16_t& x) const { y = x; }
    template <> CK_TILE_HOST_DEVICE void operator()<int32_t, int32_t>(int32_t& y, const int32_t& x) const { y = x; }
    template <> CK_TILE_HOST_DEVICE void operator()<ck_tile::bf16_t, float>(ck_tile::bf16_t& y, const float& x) const { y = type_convert<ck_tile::bf16_t>(x); }
    template <> CK_TILE_HOST_DEVICE void operator()<float, ck_tile::bf16_t>(float& y, const ck_tile::bf16_t& x) const { y = type_convert<float>(x); }
    template <> CK_TILE_HOST_DEVICE void operator()<ck_tile::bf16_t, ck_tile::fp16_t>(ck_tile::bf16_t& y, const ck_tile::fp16_t& x) const { y = type_convert<ck_tile::bf16_t>(x); }
    template <> CK_TILE_HOST_DEVICE void operator()<float, ck_tile::fp16_t>(float& y, const ck_tile::fp16_t& x) const { y = type_convert<float>(x); }
    template <> CK_TILE_HOST_DEVICE void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const { y = x; }
    template <> CK_TILE_HOST_DEVICE void operator()<ck_tile::fp16_t, int8_t>(ck_tile::fp16_t& y, const int8_t& x) const { y = type_convert<ck_tile::fp16_t>(x); }
    template <> CK_TILE_HOST_DEVICE void operator()<ck_tile::bf16_t, int8_t>(ck_tile::bf16_t& y, const int8_t& x) const { y = type_convert<ck_tile::bf16_t>(x); }
    template <> CK_TILE_HOST_DEVICE void operator()<uint8_t, uint8_t>(uint8_t& y, const uint8_t& x) const { y = x; }
    template <> CK_TILE_HOST_DEVICE void operator()<int8_t, int32_t>(int8_t& y, const int32_t& x) const { y = type_convert<int8_t>(x); }
    template <> CK_TILE_HOST_DEVICE void operator()<int32_t, int8_t>(int32_t& y, const int8_t& x) const { y = type_convert<int32_t>(x); }
    template <> CK_TILE_HOST_DEVICE void operator()<int8_t, float>(int8_t& y, const float& x) const { y = type_convert<int8_t>(x); }
    template <> CK_TILE_HOST_DEVICE void operator()<float, int8_t>(float& y, const int8_t& x) const { y = type_convert<float>(x); }

#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
    template <> CK_TILE_HOST_DEVICE void operator()<int4_t, int4_t>(int4_t& y, const int4_t& x) const { y = x; }
    template <> CK_TILE_HOST_DEVICE void operator()<int4_t, int>(int4_t& y, const int& x) const { y = type_convert<int4_t>(x); }
#endif

    template <> CK_TILE_HOST_DEVICE void operator()<ck_tile::fp8_t, ck_tile::fp8_t>(ck_tile::fp8_t& y, const ck_tile::fp8_t& x) const { y = x; }
    template <> CK_TILE_HOST_DEVICE void operator()<float, ck_tile::fp8_t>(float& y, const ck_tile::fp8_t& x) const { y = type_convert<float>(x); }
    template <> CK_TILE_HOST_DEVICE void operator()<ck_tile::fp8_t, float>(ck_tile::fp8_t& y, const float& x) const { y = type_convert<ck_tile::fp8_t>(x); }
    template <> CK_TILE_HOST_DEVICE void operator()<ck_tile::fp16_t, ck_tile::fp8_t>(ck_tile::fp16_t& y, const ck_tile::fp8_t& x) const { y = type_convert<ck_tile::fp16_t>(x); }
    template <> CK_TILE_HOST_DEVICE void operator()<ck_tile::fp8_t, ck_tile::fp16_t>(ck_tile::fp8_t& y, const ck_tile::fp16_t& x) const { y = type_convert<ck_tile::fp8_t>(x); }
    template <> CK_TILE_HOST_DEVICE void operator()<ck_tile::bf8_t, ck_tile::bf8_t>(ck_tile::bf8_t& y, const ck_tile::bf8_t& x) const { y = x; }
    template <> CK_TILE_HOST_DEVICE void operator()<float, ck_tile::bf8_t>(float& y, const ck_tile::bf8_t& x) const { y = type_convert<float>(x); }
    template <> CK_TILE_HOST_DEVICE void operator()<ck_tile::bf8_t, float>(ck_tile::bf8_t& y, const float& x) const { y = type_convert<ck_tile::bf8_t>(x); }
    template <> CK_TILE_HOST_DEVICE void operator()<ck_tile::fp16_t, ck_tile::bf8_t>(ck_tile::fp16_t& y, const ck_tile::bf8_t& x) const { y = type_convert<ck_tile::fp16_t>(x); }
    template <> CK_TILE_HOST_DEVICE void operator()<ck_tile::bf8_t, ck_tile::fp16_t>(ck_tile::bf8_t& y, const ck_tile::fp16_t& x) const { y = ck_tile::type_convert<ck_tile::bf8_t>(x); }
};
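These functors are plain callable objects, so a minimal host-side usage sketch (values illustrative) looks like:

    ck_tile::element_wise::PassThrough op;
    float y = 0.f;
    op(y, 3.0); // deduces the <float, double> specialization: y = type_convert<float>(3.0) == 3.0f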
#if 0
struct UnaryConvert
{
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
{
y = type_convert<Y>(x);
}
};
struct ConvertBF16RTN
{
// convert to bf16 using round to nearest (rtn)
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
{
// check Y datatype
static_assert(std::is_same_v<Y, ck_tile::bf16_t>, "Data type is not supported by this operation!");
// check X datatype
static_assert(std::is_same_v<X, float> || std::is_same_v<X, ck_tile::fp16_t>,
"Data type is not supported by this operation!");
y = bf16_convert_rtn<Y>(x);
}
};
struct ConvertF8SR
{
// convert to fp8 using stochastic rounding (SR)
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
{
// check Y datatype
static_assert(std::is_same_v<Y, ck_tile::fp8_t> || std::is_same_v<Y, ck_tile::bf8_t>,
"Data type is not supported by this operation!");
// check X datatype
static_assert(std::is_same_v<X, float> || std::is_same_v<X, ck_tile::fp16_t>,
"Data type is not supported by this operation!");
y = f8_convert_sr<Y>(x);
}
};
struct ConvertF8RNE
{
// convert to fp8 using rounding to nearest even
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
{
// check Y datatype
static_assert(std::is_same_v<Y, ck_tile::fp8_t> || std::is_same_v<Y, ck_tile::bf8_t>,
"Data type is not supported by this operation!");
// check X datatype
static_assert(std::is_same_v<X, float> || std::is_same_v<X, ck_tile::fp16_t>,
"Data type is not supported by this operation!");
y = f8_convert_rne<Y>(x);
}
};
#endif
struct Scale
{
    CK_TILE_HOST_DEVICE Scale(float scale = 1.f) : scale_(scale) {}

    template <typename Y, typename X>
    CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
    {
        y = ck_tile::type_convert<Y>(ck_tile::type_convert<float>(x) * scale_);
    }

    template <>
    CK_TILE_HOST_DEVICE void operator()<ck_tile::fp16_t, ck_tile::fp16_t>(ck_tile::fp16_t& y, const ck_tile::fp16_t& x) const
    {
        y = ck_tile::type_convert<ck_tile::fp16_t>(scale_) * x;
    }

    template <>
    CK_TILE_HOST_DEVICE void operator()<ck_tile::bf16_t, ck_tile::bf16_t>(ck_tile::bf16_t& y, const ck_tile::bf16_t& x) const
    {
        const float x_tmp = ck_tile::type_convert<float>(x);
        const float y_tmp = scale_ * x_tmp;
        y                 = ck_tile::type_convert<ck_tile::bf16_t>(y_tmp);
    }

    template <>
    CK_TILE_HOST_DEVICE void operator()<float, float>(float& y, const float& x) const { y = scale_ * x; }

    template <>
    CK_TILE_HOST_DEVICE void operator()<double, double>(double& y, const double& x) const { y = scale_ * x; }

    template <>
    CK_TILE_HOST_DEVICE void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const
    {
        y = ck_tile::type_convert<int8_t>(scale_ * ck_tile::type_convert<float>(x));
    }

    float scale_;
};
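A host-side usage sketch (values illustrative):

    ck_tile::element_wise::Scale scale{0.5f};
    float y = 0.f;
    scale(y, 8.0f); // <float, float> specialization: y = scale_ * x == 4.0f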
struct ScaleAndResetNaNToMinusInfinity
{
    CK_TILE_HOST_DEVICE ScaleAndResetNaNToMinusInfinity(float scale) : scale_(scale) {}

    template <typename Y, typename X>
    CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const;

    template <>
    CK_TILE_HOST_DEVICE void operator()<float, float>(float& y, const float& x) const
    {
        y = ck_tile::isnan(x) ? -numeric<float>::infinity() : scale_ * x;
    }

    float scale_;
};
struct UnaryDivide
{
    CK_TILE_HOST_DEVICE UnaryDivide(const int32_t divider = 1) : divider_(divider) {}

    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
                      std::is_same_v<T, int32_t>, "Data type is not supported by this operation!");
        y = x / type_convert<T>(divider_);
    }

    int32_t divider_ = 1;
};
struct UnarySquare
{
    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        static_assert(std::is_same_v<T, float> || std::is_same_v<T, ck_tile::fp16_t> ||
                      std::is_same_v<T, double> || std::is_same_v<T, int32_t> ||
                      std::is_same_v<T, int8_t>
#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
                      || std::is_same_v<T, int4_t>
#endif
                      , "Data type is not supported by this operation!");
        y = x * x;
    }
};
struct UnaryAbs
{
    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
                      std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int32_t> ||
                      std::is_same_v<T, int8_t>, "Data type is not supported by this operation!");
        y = ck_tile::abs(x);
    }
};

struct UnarySqrt
{
    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        static_assert(std::is_same_v<T, float> || std::is_same_v<T, double>,
                      "Data type is not supported by this operation!");
        y = ck_tile::sqrt(x);
    }
};
struct Relu
{
    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
                      std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int32_t> ||
                      std::is_same_v<T, int8_t>, "Data type is not supported by this operation!");
        y = x > 0 ? x : 0;
    }

    template <>
    CK_TILE_HOST_DEVICE void operator()(ck_tile::bf16_t& y, const ck_tile::bf16_t& x) const
    {
        float x_f32 = ck_tile::type_convert<float>(x);
        float y_f32 = x_f32 > 0 ? x_f32 : 0;
        y           = ck_tile::type_convert<ck_tile::bf16_t>(y_f32);
    }
};
// Fast GeLU
// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
// host code uses higher-accuracy "exp" and "div"
// gpu code uses lower-accuracy "__ocml_exp_f32" and "rcp" functions
struct FastGelu
{
    template <typename Y, typename X>
    CK_TILE_HOST void operator()(Y& y, const X& x) const;

    template <typename Y, typename X>
    CK_TILE_DEVICE void operator()(Y& y, const X& x) const;

    template <>
    CK_TILE_HOST void operator()<float, float>(float& y, const float& x) const
    {
        // const float u = -2.f * x * (0.035677f * x * x + 0.797885f);
        const float c1  = -2.0 * 0.035677f;
        const float c2  = -2.0 * 0.797885f;
        const float u   = x * (c1 * x * x + c2);
        const float emu = exp(u);
        y               = x / (1.f + emu);
    }

    // device code, use lower precision "__ocml_exp_f32" and "rcp"
    template <>
    CK_TILE_DEVICE void operator()<float, float>(float& y, const float& x) const
    {
        // const float u = -2.f * x * (0.035677f * x * x + 0.797885f);
        const float c1  = -2.0 * 0.035677f;
        const float c2  = -2.0 * 0.797885f;
        const float u   = x * (c1 * x * x + c2);
        const float emu = __ocml_exp_f32(u);
        y               = x * ck_tile::rcp(1.f + emu);
    }

    template <>
    CK_TILE_HOST void operator()<ck_tile::fp16_t, ck_tile::fp16_t>(ck_tile::fp16_t& y, const ck_tile::fp16_t& x) const
    {
        float y_f;
        this->operator()<float, float>(y_f, type_convert<float>(x));
        y = type_convert<ck_tile::fp16_t>(y_f);
    }

    template <>
    CK_TILE_DEVICE void operator()<ck_tile::fp16_t, ck_tile::fp16_t>(ck_tile::fp16_t& y, const ck_tile::fp16_t& x) const
    {
        float y_f;
        this->operator()<float, float>(y_f, type_convert<float>(x));
        y = type_convert<ck_tile::fp16_t>(y_f);
    }

    template <>
    CK_TILE_HOST void operator()<ck_tile::fp16_t, float>(ck_tile::fp16_t& y, const float& x) const
    {
        float y_f;
        this->operator()<float, float>(y_f, x);
        y = type_convert<ck_tile::fp16_t>(y_f);
    }

    template <>
    CK_TILE_DEVICE void operator()<ck_tile::fp16_t, float>(ck_tile::fp16_t& y, const float& x) const
    {
        float y_f;
        this->operator()<float, float>(y_f, x);
        y = type_convert<ck_tile::fp16_t>(y_f);
    }

    template <>
    CK_TILE_HOST void operator()<ck_tile::bf16_t, float>(ck_tile::bf16_t& y, const float& x) const
    {
        float y_f;
        this->operator()<float, float>(y_f, x);
        y = type_convert<ck_tile::bf16_t>(y_f);
    }

    template <>
    CK_TILE_DEVICE void operator()<ck_tile::bf16_t, float>(ck_tile::bf16_t& y, const float& x) const
    {
        float y_f;
        this->operator()<float, float>(y_f, x);
        y = type_convert<ck_tile::bf16_t>(y_f);
    }

    template <>
    CK_TILE_DEVICE void operator()<ck_tile::bf16_t, ck_tile::bf16_t>(ck_tile::bf16_t& y, const ck_tile::bf16_t& x) const
    {
        float y_f;
        this->operator()<float, float>(y_f, type_convert<float>(x));
        y = type_convert<ck_tile::bf16_t>(y_f);
    }

    template <>
    CK_TILE_HOST void operator()<ck_tile::bf16_t, ck_tile::bf16_t>(ck_tile::bf16_t& y, const ck_tile::bf16_t& x) const
    {
        float y_f;
        this->operator()<float, float>(y_f, type_convert<float>(x));
        y = type_convert<ck_tile::bf16_t>(y_f);
    }
};
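Both bodies rely on the identity 0.5\,(1+\tanh v) = 1/(1+e^{-2v}). With v = \sqrt{2/\pi}\,(x + 0.044715\,x^3) \approx 0.797885\,x + 0.035677\,x^3, the tanh form quoted in the comment becomes

    y = 0.5\,x\,(1+\tanh v) = \frac{x}{1+e^{u}}, \qquad u = -2v = x\,(c_1 x^2 + c_2),

which is exactly what the code computes with c_1 = -2 \cdot 0.035677 and c_2 = -2 \cdot 0.797885 (note 0.797885 \cdot 0.044715 \approx 0.035677).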
// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+erf(x/sqrt(2)))
struct Gelu
{
    template <typename Y, typename X>
    CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const;

    template <>
    CK_TILE_HOST_DEVICE void operator()<float, float>(float& y, const float& x) const
    {
        y = 0.5f * x * (1.f + erf(float(0.70710678118f * x)));
    }

    template <>
    CK_TILE_HOST_DEVICE void operator()<ck_tile::fp16_t, ck_tile::fp16_t>(ck_tile::fp16_t& y, const ck_tile::fp16_t& x) const
    {
        y = ck_tile::fp16_t(0.5) * x *
            (ck_tile::fp16_t(1) + ck_tile::fp16_t(erf(float(0.70710678118f * x))));
    }
};
struct Sigmoid
{
    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
                      std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
                      std::is_same_v<T, int32_t>, "Data type is not supported by this operation!");
        constexpr T one = type_convert<T>(1);
        y               = one / (one + ck_tile::exp(-x));
    }
};

struct Silu
{
    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
                      std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
                      std::is_same_v<T, int32_t>, "Data type is not supported by this operation!");
        constexpr T one = type_convert<T>(1);
        y               = x * (one / (one + ck_tile::exp(-x)));
    }
};
struct TanH
{
    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
                      std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
                      std::is_same_v<T, int32_t>, "Data type is not supported by this operation!");
        y = ck_tile::tanh(x);
    }
};

struct ACos
{
    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
                      std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
                      std::is_same_v<T, int32_t>, "Data type is not supported by this operation!");
        y = ck_tile::acos(x);
    }
};

struct Neg
{
    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
                      std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
                      std::is_same_v<T, int32_t>, "Data type is not supported by this operation!");
        y = ck_tile::neg(x);
    }
};

struct ATan
{
    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
                      std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
                      std::is_same_v<T, int32_t>, "Data type is not supported by this operation!");
        y = ck_tile::atan(x);
    }
};

struct Sin
{
    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
                      std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
                      std::is_same_v<T, int32_t>, "Data type is not supported by this operation!");
        y = ck_tile::sin(x);
    }
};

struct ASinH
{
    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
                      std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
                      std::is_same_v<T, int32_t>, "Data type is not supported by this operation!");
        y = ck_tile::asinh(x);
    }
};

struct Cos
{
    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
                      std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
                      std::is_same_v<T, int32_t>, "Data type is not supported by this operation!");
        y = ck_tile::cos(x);
    }
};

struct ACosH
{
    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
                      std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
                      std::is_same_v<T, int32_t>, "Data type is not supported by this operation!");
        y = ck_tile::acosh(x);
    }
};

struct Tan
{
    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
                      std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
                      std::is_same_v<T, int32_t>, "Data type is not supported by this operation!");
        y = ck_tile::tan(x);
    }
};

struct ATanH
{
    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
                      std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
                      std::is_same_v<T, int32_t>, "Data type is not supported by this operation!");
        y = ck_tile::atanh(x);
    }
};

struct SinH
{
    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
                      std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
                      std::is_same_v<T, int32_t>, "Data type is not supported by this operation!");
        y = ck_tile::sinh(x);
    }
};

struct Ceil
{
    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
                      std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
                      std::is_same_v<T, int32_t>, "Data type is not supported by this operation!");
        y = ck_tile::ceil(x);
    }
};

struct Exp
{
    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
                      std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
                      std::is_same_v<T, int32_t>, "Data type is not supported by this operation!");
        y = ck_tile::exp(x);
    }
};

struct CosH
{
    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
                      std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
                      std::is_same_v<T, int32_t>, "Data type is not supported by this operation!");
        y = ck_tile::cosh(x);
    }
};

struct Floor
{
    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
                      std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
                      std::is_same_v<T, int32_t>, "Data type is not supported by this operation!");
        y = ck_tile::floor(x);
    }
};

struct Log
{
    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
                      std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
                      std::is_same_v<T, int32_t>, "Data type is not supported by this operation!");
        y = ck_tile::log(x);
    }
};

struct ASin
{
    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
                      std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
                      std::is_same_v<T, int32_t>, "Data type is not supported by this operation!");
        y = ck_tile::asin(x);
    }
};

struct Rcp
{
    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
                      std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
                      std::is_same_v<T, int32_t>, "Data type is not supported by this operation!");
        y = ck_tile::rcp(x);
    }
};
struct Swish
{
    Swish(float beta = 1.0f) : beta_(beta) {}

    template <typename Y, typename X>
    CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
    {
        static_assert(std::is_same_v<X, float> || std::is_same_v<X, double> ||
                      std::is_same_v<X, ck_tile::fp16_t>, "Data type is not supported by this operation!");
        static_assert(std::is_same_v<Y, float> || std::is_same_v<Y, double> ||
                      std::is_same_v<Y, ck_tile::fp16_t>, "Data type is not supported by this operation!");
        float bx = -beta_ * type_convert<float>(x);
        y        = type_convert<Y>(x / (1.f + ck_tile::exp(bx)));
    }

    const float beta_;
};
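Note the relation to Silu above: Swish is the β-parameterized form, swish_β(x) = x / (1 + e^{-βx}) = x · sigmoid(βx), so Swish{1.0f} computes the same function as Silu.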
struct SoftRelu
{
    SoftRelu(float alpha = 1.f) : alpha_(alpha) {}

    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
                      std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int32_t> ||
                      std::is_same_v<T, int8_t>, "Data type is not supported by this operation!");
        T casted_alpha  = type_convert<T>(alpha_);
        constexpr T one = type_convert<T>(1);
        y               = ck_tile::log(one + ck_tile::exp(x * casted_alpha)) / casted_alpha;
    }

    const float alpha_;
};
struct Power
{
    Power(float alpha = 0.f, float beta = 1.f, float gamma = 2.f)
        : alpha_(alpha), beta_(beta), gamma_(gamma) {}

    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
                      std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int32_t> ||
                      std::is_same_v<T, int8_t>, "Data type is not supported by this operation!");
        T casted_alpha     = type_convert<T>(alpha_);
        T casted_beta      = type_convert<T>(beta_);
        T casted_gamma     = type_convert<T>(gamma_);
        T shifted_scaled_x = casted_alpha + casted_beta * x;
        y                  = ck_tile::pow(shifted_scaled_x, casted_gamma);
    }

    const float alpha_;
    const float beta_;
    const float gamma_;
};
struct ClippedRelu
{
    ClippedRelu(float alpha = 0.f, float beta = 1.f) : alpha_(alpha), beta_(beta) {}

    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
                      std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int32_t> ||
                      std::is_same_v<T, int8_t>, "Data type is not supported by this operation!");
        T casted_alpha = type_convert<T>(alpha_);
        T casted_beta  = type_convert<T>(beta_);
        y              = ck_tile::min(casted_beta, ck_tile::max(casted_alpha, x));
    }

    const float alpha_;
    const float beta_;
};
struct LeakyRelu
{
    LeakyRelu(float alpha = 0.01f) : alpha_(alpha) {}

    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
                      std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int32_t> ||
                      std::is_same_v<T, int8_t>, "Data type is not supported by this operation!");
        T casted_alpha = type_convert<T>(alpha_);
        y              = x >= 0 ? x : x * casted_alpha;
    }

    const float alpha_;
};
struct Elu
{
    Elu(float alpha = 1.f) : alpha_(alpha) {}

    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
                      std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int32_t> ||
                      std::is_same_v<T, int8_t>, "Data type is not supported by this operation!");
        T casted_alpha = type_convert<T>(alpha_);
        y              = x > 0 ? x : casted_alpha * ck_tile::expm1(x);
    }

    const float alpha_;
};
struct Logistic
{
    Logistic(float alpha = 1.f) : alpha_(alpha) {}

    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
                      std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int32_t> ||
                      std::is_same_v<T, int8_t>, "Data type is not supported by this operation!");
        T casted_alpha  = type_convert<T>(alpha_);
        constexpr T one = type_convert<T>(1);
        y               = casted_alpha / (one + ck_tile::exp(-x) * casted_alpha);
    }

    const float alpha_;
};
struct ConvInvscale
{
    CK_TILE_HOST_DEVICE ConvInvscale(float scale_in = 1.f, float scale_wei = 1.f, float scale_out = 1.f)
        : scale_in_(scale_in), scale_wei_(scale_wei), scale_out_(scale_out)
    {
    }

    template <typename E, typename C>
    CK_TILE_HOST_DEVICE void operator()(E& e, const C& c) const;

    template <>
    CK_TILE_HOST_DEVICE void operator()<ck_tile::fp8_t, float>(ck_tile::fp8_t& e, const float& c) const
    {
        e = type_convert<ck_tile::fp8_t>(c / scale_in_ / scale_wei_ / scale_out_);
    }

    float scale_in_;
    float scale_wei_;
    float scale_out_;
};

struct ConvScale
{
    CK_TILE_HOST_DEVICE ConvScale(float scale_in = 1.f, float scale_wei = 1.f, float scale_out = 1.f)
        : scale_in_(scale_in), scale_wei_(scale_wei), scale_out_(scale_out)
    {
    }

    template <typename E, typename C>
    CK_TILE_HOST_DEVICE void operator()(E& e, const C& c) const;

    template <>
    CK_TILE_HOST_DEVICE void operator()<ck_tile::fp8_t, float>(ck_tile::fp8_t& e, const float& c) const
    {
        e = type_convert<ck_tile::fp8_t>(c * scale_in_ * scale_wei_ * scale_out_);
    }

    float scale_in_;
    float scale_wei_;
    float scale_out_;
};

struct ConvScaleRelu
{
    CK_TILE_HOST_DEVICE ConvScaleRelu(float scale_in = 1.f, float scale_wei = 1.f, float scale_out = 1.f)
        : scale_in_(scale_in), scale_wei_(scale_wei), scale_out_(scale_out)
    {
    }

    template <typename E, typename C>
    CK_TILE_HOST_DEVICE void operator()(E& e, const C& c) const;

    template <>
    CK_TILE_HOST_DEVICE void operator()<ck_tile::fp8_t, float>(ck_tile::fp8_t& e, const float& c) const
    {
        float x;
        Relu{}.template operator()<float>(x, c * scale_in_ * scale_wei_);
        e = type_convert<ck_tile::fp8_t>(x * scale_out_);
    }

    float scale_in_;
    float scale_wei_;
    float scale_out_;
};
template <typename DstType, typename SrcType>
struct Cast
{
    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(DstType& y, const SrcType& x) const
    {
        y = ck_tile::type_convert<DstType>(x);
    }
};
// support fastconvert of int8 to fp16
#if 0
template <typename InputDataType, typename OutputDataType, index_t RegPackNumber>
struct FastNumericArrayConverter
{
};
template <>
struct FastNumericArrayConverter<uint8_t, ck_tile::fp16_t, 4>
{
using InputArray = vector_type<uint8_t, 4>;
using OutputArray = vector_type<ck_tile::fp16_t, 4>;
CK_TILE_DEVICE static OutputArray convert(InputArray const& Input)
{
OutputArray Output;
uint32_t* half_2 = reinterpret_cast<uint32_t*>(&Output);
uint32_t const uint8_4 = reinterpret_cast<uint32_t const&>(Input);
static constexpr uint32_t byte_selector_01 = 0x05010500;
static constexpr uint32_t byte_selector_23 = 0x05030502;
static constexpr uint32_t fp16_adder = 0x64646464;
half_2[0] = __builtin_amdgcn_perm(fp16_adder, uint8_4, byte_selector_01);
half_2[1] = __builtin_amdgcn_perm(fp16_adder, uint8_4, byte_selector_23);
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
asm volatile("v_pk_add_f16 %0, %1, %2 neg_lo:[0,1] neg_hi:[0,1]"
: "=v"(half_2[0])
: "v"(half_2[0]), "s"(I8s_TO_F16s_MAGIC_NUM));
asm volatile("v_pk_add_f16 %0, %1, %2 neg_lo:[0,1] neg_hi:[0,1]"
: "=v"(half_2[1])
: "v"(half_2[1]), "s"(I8s_TO_F16s_MAGIC_NUM));
return Output;
}
CK_TILE_DEVICE OutputArray operator()(InputArray const& Input) { return convert(Input); }
};
template <index_t N>
struct FastNumericArrayConverter<uint8_t, ck_tile::fp16_t, N>
{
static constexpr int VEC_WIDTH = 4;
static_assert(!(N % VEC_WIDTH), "N must be multiple of 4.");
using InputArray = vector_type<uint8_t, N>;
using OutputArray = vector_type<ck_tile::fp16_t, N>;
CK_TILE_DEVICE static OutputArray convert(InputArray const& Input)
{
FastNumericArrayConverter<uint8_t, ck_tile::fp16_t, 4> converter;
OutputArray Output;
using Vec_InputArray = vector_type<uint8_t, 4>;
using Vec_OutputArray = vector_type<ck_tile::fp16_t, 4>;
Vec_OutputArray* half_4_ptr = reinterpret_cast<Vec_OutputArray*>(&Output);
Vec_InputArray const* uint8_4_ptr = reinterpret_cast<Vec_InputArray const*>(&Input);
static_for<0, N / VEC_WIDTH, 1>{}(
[&](auto i) { half_4_ptr[i] = converter(uint8_4_ptr[i]); });
return Output;
}
CK_TILE_DEVICE OutputArray operator()(InputArray const& Input) { return convert(Input); }
};
#endif
} // namespace element_wise
} // namespace ck_tile
include/ck_tile/ops/epilogue.hpp

...
@@ -5,4 +5,6 @@
 #include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"
 #include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
+#include "ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp"
+#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/epilogue/default_2d_epilogue.hpp

...
@@ -9,23 +9,29 @@ namespace ck_tile {

 // this epilogue just store out a M*N matrix, row major
-template <typename AccDataType_, typename ODataType_, bool kPadM_, bool kPadN_>
+template <typename AccDataType_,
+          typename ODataType_,
+          bool kPadM_,
+          bool kPadN_,
+          bool UseRawStore_ = true>
 struct Default2DEpilogueProblem
 {
     using AccDataType = remove_cvref_t<AccDataType_>;
     using ODataType   = remove_cvref_t<ODataType_>;
-    static constexpr bool kPadM = kPadM_;
-    static constexpr bool kPadN = kPadN_;
+    static constexpr bool kPadM       = kPadM_;
+    static constexpr bool kPadN       = kPadN_;
+    static constexpr bool UseRawStore = UseRawStore_;
 };

 template <typename Problem_, typename Policy_ = void>
 struct Default2DEpilogue
 {
     using Problem     = remove_cvref_t<Problem_>;
     using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
     using ODataType   = remove_cvref_t<typename Problem::ODataType>;
-    static constexpr bool kPadM = Problem::kPadM;
-    static constexpr bool kPadN = Problem::kPadN;
+    static constexpr bool kPadM       = Problem::kPadM;
+    static constexpr bool kPadN       = Problem::kPadN;
+    static constexpr bool UseRawStore = Problem::UseRawStore;

     CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; }
...
@@ -36,7 +42,7 @@ struct Default2DEpilogue
     {
         // TODO: this is ugly
-        if constexpr(kPadM || kPadN)
+        if constexpr(UseRawStore && (kPadM || kPadN))
         {
             store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
             buffer_store_fence();
...
...
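The new UseRawStore_ knob defaults to true, which preserves the old behavior (raw store plus fence whenever padding is enabled); passing false forces the plain store_tile path even with padding. A hypothetical instantiation, sketched from the signature above:

    // Hypothetical problem type: fp32 accumulation, fp16 output, padded N, raw stores disabled.
    using Problem  = ck_tile::Default2DEpilogueProblem<float,            // AccDataType_
                                                       ck_tile::fp16_t, // ODataType_
                                                       false,           // kPadM_
                                                       true,            // kPadN_
                                                       false>;          // UseRawStore_
    using Epilogue = ck_tile::Default2DEpilogue<Problem>;
    static_assert(!Epilogue::UseRawStore);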
include/ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp  0 → 100644 (new file)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/reduce.hpp"

namespace ck_tile {

template <bool kPadM_,
          bool kPadN_,
          bool UseSmoothInputScale_,
          bool UseRawStore_ = true,
          bool UseMax3_     = false>
struct DynamicQuantEpilogueTraits
{
    static constexpr bool kPadM               = kPadM_;
    static constexpr bool kPadN               = kPadN_;
    static constexpr bool UseSmoothInputScale = UseSmoothInputScale_;
    static constexpr bool UseRawStore         = UseRawStore_;
    static constexpr bool UseMax3             = UseMax3_;
};

// this epilogue just store out a M*N matrix, row major
template <typename AccDataType_,
          typename XScaleDataType_,
          typename YScaleDataType_,
          typename ODataType_,
          typename BlockShape_,
          typename Traits_>
struct DynamicQuantEpilogueProblem
{
    using AccDataType    = remove_cvref_t<AccDataType_>;
    using XScaleDataType = remove_cvref_t<XScaleDataType_>;
    using YScaleDataType = remove_cvref_t<YScaleDataType_>;
    using ODataType      = remove_cvref_t<ODataType_>;
    using BlockShape     = remove_cvref_t<BlockShape_>; // can consume generic 2d shape
    using Traits         = remove_cvref_t<Traits_>;
};

// TODO: we should put descriptor creation function into policy
template <typename Problem_, typename Policy_ = void>
struct DynamicQuantEpilogue
{
    using Problem        = remove_cvref_t<Problem_>;
    using AccDataType    = remove_cvref_t<typename Problem::AccDataType>;
    using XScaleDataType = remove_cvref_t<typename Problem::XScaleDataType>;
    using YScaleDataType = remove_cvref_t<typename Problem::YScaleDataType>;
    using ODataType      = remove_cvref_t<typename Problem::ODataType>;
    using BlockShape     = remove_cvref_t<typename Problem::BlockShape>;

    static constexpr bool kPadM       = Problem::Traits::kPadM;
    static constexpr bool kPadN       = Problem::Traits::kPadN;
    static constexpr bool UseRawStore = Problem::Traits::UseRawStore;
    static constexpr bool UseMax3     = Problem::Traits::UseMax3;

    CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2d()
    {
        using P_ = BlockReduce2dProblem<AccDataType, AccDataType, BlockShape>;
        return BlockReduce2d<P_>{};
    }

    CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dSync()
    {
        using P_ = BlockReduce2dProblem<AccDataType, AccDataType, BlockShape>;
        return BlockReduce2dSync<P_>{};
    }

    CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dCrossWarpSync()
    {
        using P_ = BlockReduce2dProblem<AccDataType, AccDataType, BlockShape>;
        return BlockReduce2dCrossWarpSync<P_>{};
    }

    CK_TILE_DEVICE static constexpr auto MakeSmoothInputScaleTileDistribution()
    {
        using S = BlockShape;
#if 0
        // don't remove this
        // Note that if we set encoding purposely like this, you will result in compile fail
        // TODO: x_scale create local-scratch to accept arbitrary acc input (with same length)
        return make_static_tile_distribution(
            tile_distribution_encoding<
                sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M>,
                tuple<sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::Vector_N>>,
                tuple<sequence<0, 1>, sequence<0, 1>>,
                tuple<sequence<1, 1>, sequence<2, 2>>,
                sequence<0, 1, 1>,
                sequence<0, 0, 3>>{});
#else
        return make_static_tile_distribution(
            tile_distribution_encoding<
                sequence<S::WarpPerBlock_M, S::ThreadPerWarp_M>,
                tuple<sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::Vector_N>>,
                tuple<sequence<0, 1>, sequence<0, 1>>,
                tuple<sequence<0, 1>, sequence<1, 2>>,
                sequence<1, 1>,
                sequence<0, 3>>{});
#endif
    }

    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
    {
        auto reduce_crosswarp_sync = GetBlockReduce2dCrossWarpSync();
        return reduce_crosswarp_sync.GetSmemSize();
    }

    // TODO: this function assume store out vector size is the same as OAccTile last dimension size
    // how do we fix this ?
    template <typename ODramWindowTmp, typename XScaleWindow, typename YScaleWindow, typename OAccTile>
    CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp,
                                   const XScaleWindow& x_scale_window_,
                                   YScaleWindow& y_scale_window,
                                   const OAccTile& o_acc_tile,
                                   void* smem)
    {
        auto reduce                = GetBlockReduce2d();
        auto reduce_sync           = GetBlockReduce2dSync();
        auto reduce_crosswarp_sync = GetBlockReduce2dCrossWarpSync();

        const auto x_scale_window =
            make_tile_window(x_scale_window_, MakeSmoothInputScaleTileDistribution());
        auto x_scale = load_tile(x_scale_window);

        auto o_acc_tmp = o_acc_tile;
        sweep_tile(o_acc_tmp, [&](auto idx) {
            constexpr auto j_idx = make_tuple(idx[number<1>{}]);
            const auto xs_       = type_convert<AccDataType>(x_scale[j_idx]);
            o_acc_tmp(idx)       = o_acc_tmp(idx) * xs_;
        });

        const auto f_absmax = [](auto acc_, auto v_0_) { return max(acc_, abs(v_0_)); };

        auto row_absmax = [&]() {
            constexpr auto y_size_per_row = OAccTile{}
                                                .get_tile_distribution()
                                                .get_ys_to_d_descriptor()
                                                .get_lengths()
                                                .at(number<1>{});
            if constexpr(UseMax3 && std::is_same_v<AccDataType, float> && y_size_per_row % 2 == 0)
            {
                // fast max3+abs implementation
                const auto f_max3 = [](auto acc_, auto v_0_, auto v_1_) {
                    float rtn;
                    asm volatile("v_max3_f32 %0, %1, abs(%2), abs(%3)"
                                 : "=v"(rtn)
                                 : "v"(acc_), "v"(v_0_), "v"(v_1_));
                    return rtn;
                };
                return reduce(o_acc_tmp, type_convert<AccDataType>(0), f_max3, sequence<1, 2>{});
            }
            else
            {
                return reduce(o_acc_tmp, type_convert<AccDataType>(0), f_absmax);
            }
        }();

        reduce_sync(row_absmax, f_absmax);
        reduce_crosswarp_sync(row_absmax, smem, f_absmax);

        // here y_scale is Acc type, need convert to YScale type later
        auto y_scale = tile_elementwise_in(
            [&](const auto& v_) { return v_ / type_convert<AccDataType>(numeric<ODataType>::max()); },
            row_absmax);

        store_tile(y_scale_window, cast_tile<YScaleDataType>(y_scale));

        sweep_tile(o_acc_tmp, [&](auto idx) {
            constexpr auto row_id = make_tuple(idx[number<0>{}]);
            o_acc_tmp(idx)        = o_acc_tmp[idx] / y_scale(row_id);
        });

        // TODO: this is ugly
        if constexpr(UseRawStore && (kPadM || kPadN))
        {
            store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tmp));
            buffer_store_fence();
        }
        else
        {
            store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tmp));
        }
    }
};

} // namespace ck_tile
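What the epilogue computes per row, modeled on the host for scalars (this plain-C++ rewrite is illustrative only; the real code operates on distributed tiles and device reductions): multiply by the per-column smooth input scale, take the row's absolute maximum, derive y_scale = absmax / numeric<ODataType>::max(), then divide the row by y_scale before the narrow store.

    // Illustrative scalar model of the per-row dynamic quantization, int8 output assumed.
    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <vector>

    std::vector<int8_t> quantize_row(std::vector<float> row,
                                     const std::vector<float>& x_scale, // smooth input scale
                                     float& y_scale_out)
    {
        float absmax = 0.f;
        for(std::size_t j = 0; j < row.size(); ++j)
        {
            row[j] *= x_scale[j];                          // o_acc_tmp(idx) * xs_
            absmax = std::max(absmax, std::fabs(row[j]));  // the f_absmax reduction
        }
        y_scale_out = absmax / 127.f;                      // absmax / numeric<int8_t>::max()
        std::vector<int8_t> out(row.size());
        for(std::size_t j = 0; j < row.size(); ++j)
            // rounding mode here is illustrative; cast_tile<ODataType> governs the real conversion
            out[j] = static_cast<int8_t>(std::lround(row[j] / y_scale_out));
        return out;
    }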
include/ck_tile/ops/fmha.hpp

...
@@ -43,4 +43,5 @@
 #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
 #include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp"
 #include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp"
+#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp

...
@@ -82,10 +82,10 @@ struct FmhaFwdKernel
             if(kPadHeadDimV) n += "dv";
             return n.empty() ? n : std::string("p") + n;
         }();
-        return _SS_("fmha_fwd_d") + _TS_(bfs::kK0BlockLength) + "_" + _SS_(t2s<QDataType>::name) +
+        return _SS_("fmha_fwd_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s<QDataType>::name) +
               "_" + (kIsGroupMode ? "group" : "batch") + "_" + _SS_(TilePartitioner::name) + "_"
               "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" +
-              _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kK0BlockLength) + "_" +
+              _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" +
               "r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) +
               "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" +
               "r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) +
               "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" +
               "w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) +
               "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" +
...
@@ -657,7 +657,7 @@ struct FmhaFwdKernel
        {
            return pad_tensor_view(
                q_dram_naive,
-               make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0BlockLength>{}),
+               make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kSubQKHeaddim>{}),
                sequence<kPadSeqLenQ, kPadHeadDimQ>{});
        }
        else
...
@@ -724,7 +724,7 @@ struct FmhaFwdKernel
            [&]() {
                if constexpr(FmhaPipeline::kQLoadOnce)
                    return make_tuple(number<FmhaPipeline::kM0>{},
-                                     number<FmhaPipeline::kK0BlockLength>{});
+                                     number<FmhaPipeline::kSubQKHeaddim>{});
                else
                    return make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{});
            }(),
...
...
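Reading the return expression off the hunk above, the generated kernel name is laid out schematically as follows (fields marked {..} come from template constants elided in this hunk):

    fmha_fwd_d{kQKHeaddim}_{QDataType}_{group|batch}_{TilePartitioner}_b{kM0}x{kN0}x{kK0}x{kN1}x{kK1}x{kQKHeaddim}_r{..}x{..}x{..}_r{..}x{..}x{..}_w{..}x{..}x{..}_...

That is, the leading "d" field and the last "b" field now report the QK head dimension rather than kK0BlockLength.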
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp

...
@@ -78,10 +78,10 @@ struct FmhaFwdSplitKVKernel
             if(kPadHeadDimV) n += "dv";
             return n.empty() ? n : std::string("p") + n;
         }();
-        return _SS_("fmha_fwd_splitkv_d") + _TS_(bfs::kK0BlockLength) + "_" + _SS_(t2s<QDataType>::name) +
+        return _SS_("fmha_fwd_splitkv_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s<QDataType>::name) +
               "_" + (kIsGroupMode ? "group" : "batch") + "_"
               "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" +
-              _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kK0BlockLength) + "_" +
+              _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" +
               "r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) +
               "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" +
               "r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) +
               "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" +
               "w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) +
               "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" +
...
@@ -586,7 +586,7 @@ struct FmhaFwdSplitKVKernel
        {
            return pad_tensor_view(
                q_dram_naive,
-               make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0BlockLength>{}),
+               make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kSubQKHeaddim>{}),
                sequence<kPadSeqLenQ, kPadHeadDimQ>{});
        }
        else
...
@@ -735,7 +735,7 @@ struct FmhaFwdSplitKVKernel
            [&]() {
                if constexpr(FmhaPipeline::kQLoadOnce)
                    return make_tuple(number<FmhaPipeline::kM0>{},
-                                     number<FmhaPipeline::kK0BlockLength>{});
+                                     number<FmhaPipeline::kSubQKHeaddim>{});
                else
                    return make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{});
            }(),
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp

...
@@ -34,12 +34,13 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
     static constexpr index_t kBlockSize = Problem::kBlockSize;

-    static constexpr index_t kM0            = BlockFmhaShape::kM0;
-    static constexpr index_t kN0            = BlockFmhaShape::kN0;
-    static constexpr index_t kK0            = BlockFmhaShape::kK0;
-    static constexpr index_t kN1            = BlockFmhaShape::kN1;
-    static constexpr index_t kK1            = BlockFmhaShape::kK1;
-    static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength;
+    static constexpr index_t kM0           = BlockFmhaShape::kM0;
+    static constexpr index_t kN0           = BlockFmhaShape::kN0;
+    static constexpr index_t kK0           = BlockFmhaShape::kK0;
+    static constexpr index_t kN1           = BlockFmhaShape::kN1;
+    static constexpr index_t kK1           = BlockFmhaShape::kK1;
+    static constexpr index_t kQKHeaddim    = BlockFmhaShape::kQKHeaddim;
+    static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;

     static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
     static constexpr bool kPadSeqLenQ  = Problem::kPadSeqLenQ;
...
@@ -75,22 +76,22 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
             return Problem::kBlockPerCu;
         else
         {
-            if constexpr(kK0BlockLength <= 32)
+            if constexpr(kQKHeaddim <= 32)
             {
                 return 2;
             }
-            else if constexpr(kK0BlockLength <= 64)
+            else if constexpr(kQKHeaddim <= 64)
             {
                 return 3;
             }
-            else if constexpr(kK0BlockLength <= 128)
+            else if constexpr(kQKHeaddim <= 128)
             {
                 if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
                     return 1;
                 else
                     return 2;
             }
-            else if constexpr(kK0BlockLength <= 256)
+            else if constexpr(kQKHeaddim <= 256)
             {
                 return 1;
             }
...
@@ -270,7 +271,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
         // prefetch K tile
         index_t i_total_loops      = 0;
-        constexpr index_t k0_loops = kK0BlockLength / kK0;
+        constexpr index_t k0_loops = kQKHeaddim / kK0;
         constexpr index_t k1_loops = kN0 / kK1;

         static_assert(2 <= k0_loops);
...
...
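Since the commit renames kK0BlockLength to kQKHeaddim consistently, the loop-count change is behavior-preserving. A worked example with hypothetical tile sizes kQKHeaddim = 128 and kK0 = 32:

    constexpr index_t k0_loops = kQKHeaddim / kK0; // 128 / 32 = 4 K-prefetch steps over the head dim
    static_assert(2 <= k0_loops);                  // holds: 4 >= 2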
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp

...
@@ -37,12 +37,13 @@ struct BlockFmhaPipelineQRKSVS
     static constexpr index_t kBlockSize = Problem::kBlockSize;

-    static constexpr index_t kM0            = BlockFmhaShape::kM0;
-    static constexpr index_t kN0            = BlockFmhaShape::kN0;
-    static constexpr index_t kK0            = BlockFmhaShape::kK0;
-    static constexpr index_t kN1            = BlockFmhaShape::kN1;
-    static constexpr index_t kK1            = BlockFmhaShape::kK1;
-    static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength;
+    static constexpr index_t kM0           = BlockFmhaShape::kM0;
+    static constexpr index_t kN0           = BlockFmhaShape::kN0;
+    static constexpr index_t kK0           = BlockFmhaShape::kK0;
+    static constexpr index_t kN1           = BlockFmhaShape::kN1;
+    static constexpr index_t kK1           = BlockFmhaShape::kK1;
+    static constexpr index_t kQKHeaddim    = BlockFmhaShape::kQKHeaddim;
+    static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;

     static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
     static constexpr bool kPadSeqLenQ  = Problem::kPadSeqLenQ;
...
@@ -76,22 +77,22 @@ struct BlockFmhaPipelineQRKSVS
             return Problem::kBlockPerCu;
         else
         {
-            if constexpr(kK0BlockLength <= 32)
+            if constexpr(kQKHeaddim <= 32)
             {
                 return 2;
             }
-            else if constexpr(kK0BlockLength <= 64)
+            else if constexpr(kQKHeaddim <= 64)
             {
                 return 3;
             }
-            else if constexpr(kK0BlockLength <= 128)
+            else if constexpr(kQKHeaddim <= 128)
             {
                 if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
                     return 1;
                 else
                     return 2;
             }
-            else if constexpr(kK0BlockLength <= 256)
+            else if constexpr(kQKHeaddim <= 256)
             {
                 return 1;
             }
...
@@ -261,7 +262,7 @@ struct BlockFmhaPipelineQRKSVS
         // prefetch K tile
         index_t i_total_loops      = 0;
-        constexpr index_t k0_loops = kK0BlockLength / kK0;
+        constexpr index_t k0_loops = kQKHeaddim / kK0;
         constexpr index_t k1_loops = kN0 / kK1;

         static_assert(2 <= k0_loops);
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp

...
@@ -38,12 +38,13 @@ struct BlockFmhaPipelineQRKSVSAsync
     static constexpr index_t kBlockSize = Problem::kBlockSize;

-    static constexpr index_t kM0            = BlockFmhaShape::kM0;
-    static constexpr index_t kN0            = BlockFmhaShape::kN0;
-    static constexpr index_t kK0            = BlockFmhaShape::kK0;
-    static constexpr index_t kN1            = BlockFmhaShape::kN1;
-    static constexpr index_t kK1            = BlockFmhaShape::kK1;
-    static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength;
+    static constexpr index_t kM0           = BlockFmhaShape::kM0;
+    static constexpr index_t kN0           = BlockFmhaShape::kN0;
+    static constexpr index_t kK0           = BlockFmhaShape::kK0;
+    static constexpr index_t kN1           = BlockFmhaShape::kN1;
+    static constexpr index_t kK1           = BlockFmhaShape::kK1;
+    static constexpr index_t kQKHeaddim    = BlockFmhaShape::kQKHeaddim;
+    static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;

     static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
     // TODO: seq_q always support padding, hdim_q/v support multiple of vector(like 8x)
...
@@ -87,7 +88,7 @@ struct BlockFmhaPipelineQRKSVSAsync
                 return 1;
             }
-            if constexpr(kK0BlockLength <= 32)
+            if constexpr(kQKHeaddim <= 32)
             {
                 if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS &&
                              FmhaMask::IsMasking)
...
@@ -95,21 +96,21 @@ struct BlockFmhaPipelineQRKSVSAsync
                 else
                     return 2;
             }
-            else if constexpr(kK0BlockLength <= 64)
+            else if constexpr(kQKHeaddim <= 64)
             {
                 if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
                     return 2;
                 else
                     return 3;
             }
-            else if constexpr(kK0BlockLength <= 128)
+            else if constexpr(kQKHeaddim <= 128)
             {
                 if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
                     return 1;
                 else
                     return 2;
             }
-            else if constexpr(kK0BlockLength <= 256)
+            else if constexpr(kQKHeaddim <= 256)
             {
                 return 1;
             }
...
@@ -334,12 +335,12 @@ struct BlockFmhaPipelineQRKSVSAsync
         move_tile_window(k_dram_window, {0, kK0});
         __builtin_amdgcn_sched_barrier(0);
-        buffer_load_fence(k_dram_window.get_num_access(), q.get_thread_buffer());
+        buffer_load_fence(k_dram_window.get_num_of_access(), q.get_thread_buffer());
         (void)q_element_func;
         // ??? rocm-6.x if use q element func will have scratch on hdim=64/32
         // auto q_tile = q; // tile_elementwise_in(q_element_func, q);
         index_t i_total_loops      = 0;
-        constexpr index_t k0_loops = kK0BlockLength / kK0;
+        constexpr index_t k0_loops = kQKHeaddim / kK0;
         constexpr index_t k1_loops = kN0 / kK1;
         static_assert(1 <= k0_loops);
...
@@ -359,7 +360,7 @@ struct BlockFmhaPipelineQRKSVSAsync
             if constexpr(i_k0 < k0_loops - 1)
                 move_tile_window(k_dram_window, {0, kK0});

-            async_load_fence(k_dram_window.get_num_access());
+            async_load_fence(k_dram_window.get_num_of_access());
             __builtin_amdgcn_s_barrier();
             __builtin_amdgcn_sched_barrier(0);
             gemm_0(s_acc,
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp
View file @
0475a327
...
...
@@ -36,12 +36,12 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
     static constexpr index_t kBlockSize = Problem::kBlockSize;

-    static constexpr index_t kM0            = BlockFmhaShape::kM0;
-    static constexpr index_t kN0            = BlockFmhaShape::kN0;
-    static constexpr index_t kK0            = BlockFmhaShape::kK0;
-    static constexpr index_t kN1            = BlockFmhaShape::kN1;
-    static constexpr index_t kK1            = BlockFmhaShape::kK1;
-    static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength;
+    static constexpr index_t kM0        = BlockFmhaShape::kM0;
+    static constexpr index_t kN0        = BlockFmhaShape::kN0;
+    static constexpr index_t kK0        = BlockFmhaShape::kK0;
+    static constexpr index_t kN1        = BlockFmhaShape::kN1;
+    static constexpr index_t kK1        = BlockFmhaShape::kK1;
+    static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;

     static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
     static constexpr bool kPadSeqLenQ  = Problem::kPadSeqLenQ;
...
@@ -75,22 +75,22 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
             return Problem::kBlockPerCu;
         else
         {
-            if constexpr(kK0BlockLength <= 32)
+            if constexpr(kQKHeaddim <= 32)
             {
                 return 2;
             }
-            else if constexpr(kK0BlockLength <= 64)
+            else if constexpr(kQKHeaddim <= 64)
             {
                 return 3;
             }
-            else if constexpr(kK0BlockLength <= 128)
+            else if constexpr(kQKHeaddim <= 128)
             {
                 if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
                     return 1;
                 else
                     return 2;
             }
-            else if constexpr(kK0BlockLength <= 256)
+            else if constexpr(kQKHeaddim <= 256)
             {
                 return 1;
             }
...
@@ -232,7 +232,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
         // prefetch K tile
         index_t i_total_loops      = 0;
-        constexpr index_t k0_loops = kK0BlockLength / kK0;
+        constexpr index_t k0_loops = kQKHeaddim / kK0;
         constexpr index_t k1_loops = kN0 / kK1;

         static_assert(2 <= k0_loops);
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp
View file @ 0475a327
...
@@ -36,12 +36,13 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
     static constexpr index_t kBlockSize = Problem::kBlockSize;

-    static constexpr index_t kM0            = BlockFmhaShape::kM0;
-    static constexpr index_t kN0            = BlockFmhaShape::kN0;
-    static constexpr index_t kK0            = BlockFmhaShape::kK0;
-    static constexpr index_t kN1            = BlockFmhaShape::kN1;
-    static constexpr index_t kK1            = BlockFmhaShape::kK1;
-    static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength;
+    static constexpr index_t kM0           = BlockFmhaShape::kM0;
+    static constexpr index_t kN0           = BlockFmhaShape::kN0;
+    static constexpr index_t kK0           = BlockFmhaShape::kK0;
+    static constexpr index_t kN1           = BlockFmhaShape::kN1;
+    static constexpr index_t kK1           = BlockFmhaShape::kK1;
+    static constexpr index_t kQKHeaddim    = BlockFmhaShape::kQKHeaddim;
+    static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;

     static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
     static constexpr bool kPadSeqLenQ  = Problem::kPadSeqLenQ;
...
@@ -56,22 +57,22 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
             return Problem::kBlockPerCu;
         else
         {
-            if constexpr(kK0BlockLength <= 32)
+            if constexpr(kQKHeaddim <= 32)
             {
                 return 2;
             }
-            else if constexpr(kK0BlockLength <= 64)
+            else if constexpr(kQKHeaddim <= 64)
             {
                 return 3;
             }
-            else if constexpr(kK0BlockLength <= 128)
+            else if constexpr(kQKHeaddim <= 128)
             {
                 if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
                     return 1;
                 else
                     return 2;
             }
-            else if constexpr(kK0BlockLength <= 256)
+            else if constexpr(kQKHeaddim <= 256)
             {
                 return 1;
             }
...
@@ -235,7 +236,7 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
         // prefetch K tile
         index_t i_total_loops      = 0;
-        constexpr index_t k0_loops = kK0BlockLength / kK0;
+        constexpr index_t k0_loops = kQKHeaddim / kK0;
         constexpr index_t k1_loops = kN0 / kK1;

         static_assert(2 <= k0_loops);
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
View file @ 0475a327
...
@@ -55,7 +55,7 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
         constexpr index_t MWarp = config.template at<1>();

         constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
-        constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0BlockLength;
+        constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;

         constexpr index_t K2 = WG::kK / WG::WarpGemmAttribute::Impl::kABKLane;
         constexpr index_t K1 = WG::WarpGemmAttribute::Impl::kABKLane;
...
@@ -323,6 +323,9 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
     template<> struct LdsBufferSequence<3, 3, 3, 3> { using type = sequence<1, 2, 0, 1, 2, 0>; };
     template<> struct LdsBufferSequence<3, 3, 3, 4> { using type = sequence<1, 2, 0, 0, 1, 2, 0>; };
+
+    template<> struct LdsBufferSequence<3, 3, 2, 2> { using type = sequence<1, 2, 1, 0>; };
+
     // clang-format on
...
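A hedged reading of the LdsBufferSequence table (an inference from the specializations above, not documented behavior): the sequence length always equals k0_loops + k1_loops, and each entry names which of the NumPrefetch LDS buffers the corresponding pipeline stage should use, cycling through the available buffers. The new <3, 3, 2, 2> entry covers the shorter k0_loops = k1_loops = 2 case:

    // sequence<1, 2, 1, 0>: four stages (k0_loops + k1_loops = 2 + 2)
    // cycling through three prefetch LDS buffers.
    constexpr int lds_buffer_of_stage[2 + 2] = {1, 2, 1, 0};
    static_assert(sizeof(lds_buffer_of_stage) / sizeof(int) == 4);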
@@ -332,12 +335,12 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
     {
         using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;

-        constexpr index_t kN0            = BlockFmhaShape::kN0;
-        constexpr index_t kK0            = BlockFmhaShape::kK0;
-        constexpr index_t kK1            = BlockFmhaShape::kK1;
-        constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength;
+        constexpr index_t kN0        = BlockFmhaShape::kN0;
+        constexpr index_t kK0        = BlockFmhaShape::kK0;
+        constexpr index_t kK1        = BlockFmhaShape::kK1;
+        constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;

-        constexpr index_t k0_loops = kK0BlockLength / kK0;
+        constexpr index_t k0_loops = kQKHeaddim / kK0;
         constexpr index_t k1_loops = kN0 / kK1;

         return typename LdsBufferSequence<NumPrefetchK, NumPrefetchV, k0_loops, k1_loops>::type{};
...
include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp
View file @ 0475a327
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
...
@@ -7,6 +7,20 @@
 namespace ck_tile {

+static CK_TILE_HOST_DEVICE constexpr index_t ceil_to_qualified_tile_length(index_t len)
+{
+    if(len == 96)
+        return 128;
+    if(len == 160)
+        return 256;
+    // only length of 96, 160 and power-of-two is supported
+    if(!(len & (len - 1)))
+        return len;
+    return 0;
+};
+
 template <typename BlockTile_,        // sequence<...
           typename Gemm0BlockWarps_,
           typename Gemm0WarpTile_,
...
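Spot checks for ceil_to_qualified_tile_length, restated as a standalone constexpr mirror so the asserts compile on their own (the values are read directly off the branches above):

    constexpr int ceil_len(int len) // plain-C++ mirror of ceil_to_qualified_tile_length
    {
        if(len == 96)
            return 128;
        if(len == 160)
            return 256;
        return (len & (len - 1)) ? 0 : len;
    }
    static_assert(ceil_len(96) == 128);  // 96 is padded up to the next qualified length
    static_assert(ceil_len(160) == 256); // 160 likewise
    static_assert(ceil_len(64) == 64);   // powers of two pass through unchanged
    static_assert(ceil_len(100) == 0);   // anything else is unsupported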
@@ -36,10 +50,12 @@ struct TileFmhaShape
     static constexpr index_t kK0 = BlockTile::at(number<2>{}); // tile size along qk gemm unroll
     static constexpr index_t kN1 = BlockTile::at(number<3>{}); // tile size along v head_dim
     static constexpr index_t kK1 = BlockTile::at(number<4>{}); // tile size along kv gemm unroll
-    static constexpr index_t kK0BlockLength =
+    static constexpr index_t kQKHeaddim =
         BlockTile::at(number<5>{}); // total length of K0, used for pipeline that need load Q at
                                     // once (or repeately load Q as a whole tile)
-    static_assert(kK0BlockLength % kK0 == 0, "kK0BlockLength should be divisible by kK0");
+    static_assert(kQKHeaddim % kK0 == 0, "kQKHeaddim should be divisible by kK0");
+
+    static constexpr index_t kSubQKHeaddim = ceil_to_qualified_tile_length(kQKHeaddim);

     // v, rowmajor : seqlen*hdim, colmajor : hdim*seqlen
     static constexpr bool IsVLayoutRowMajor = IsVLayoutRowMajor_;
...
include/ck_tile/ops/gemm.hpp
View file @ 0475a327
...
@@ -24,6 +24,8 @@
 #include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
 #include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
 #include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
+#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp"
+#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
 #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp"
 #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
 #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp"
...
@@ -37,4 +39,5 @@
 #include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp"
 #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
 #include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp"
+#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp
View file @ 0475a327
...
@@ -32,7 +32,7 @@ struct BlockGemmARegBGmemCRegV1
                          BlockGemmProblem<ADataType, BDataType, CDataType, kBlockSize, BlockGemmShape>,
                          BlockGemmARegBGmemCRegV1DefaultPolicy>;

-    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize()
+    CK_TILE_HOST_DEVICE static constexpr index_t GetStaticLdsSize()
     {
         return sizeof(BDataType) *
                Policy::template MakeBSmemBlockDescriptor<Problem>().get_element_space_size();
...
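As a worked example of GetStaticLdsSize() (the figures are assumptions for illustration, not a configuration taken from this commit): with BDataType = fp16 and a B LDS block descriptor covering a 128 x 32 tile, the static LDS requirement would come out as

    constexpr int lds_bytes = 2 /*sizeof(fp16)*/ * 128 * 32; // = 8192 bytes (8 KiB)
    static_assert(lds_bytes == 8192);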
include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp
View file @ 0475a327
...
@@ -24,19 +24,19 @@ struct BlockGemmASmemBSmemCRegV1
     static constexpr index_t kBlockSize = Problem::kBlockSize;

     // C += A * B
-    template <typename CBlockTensor, typename ABlockWindowTmp, typename BBlockWindowTmp>
+    template <typename CBlockTensor, typename ABlockWindow, typename BBlockWindow>
     CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
-                                   const ABlockWindowTmp& a_block_window_tmp,
-                                   const BBlockWindowTmp& b_block_window_tmp) const
+                                   const ABlockWindow& a_block_window,
+                                   const BBlockWindow& b_block_window) const
     {
-        static_assert(std::is_same_v<ADataType, typename ABlockWindowTmp::DataType> &&
-                          std::is_same_v<BDataType, typename BBlockWindowTmp::DataType> &&
+        static_assert(std::is_same_v<ADataType, typename ABlockWindow::DataType> &&
+                          std::is_same_v<BDataType, typename BBlockWindow::DataType> &&
                           std::is_same_v<CDataType, typename CBlockTensor::DataType>,
                       "wrong!");

-        constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}];
-        constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
-        constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}];
+        constexpr index_t MPerBlock = ABlockWindow{}.get_window_lengths()[number<0>{}];
+        constexpr index_t NPerBlock = BBlockWindow{}.get_window_lengths()[number<0>{}];
+        constexpr index_t KPerBlock = ABlockWindow{}.get_window_lengths()[number<1>{}];

         static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
                           KPerBlock == BlockGemmShape::kK,
...
@@ -62,9 +62,9 @@ struct BlockGemmASmemBSmemCRegV1
         // construct A-warp-window
         auto a_warp_window_tmp = make_tile_window(
-            a_block_window_tmp.get_bottom_tensor_view(),
+            a_block_window.get_bottom_tensor_view(),
             make_tuple(number<WG::kM>{}, number<WG::kK>{}),
-            a_block_window_tmp.get_window_origin() + multi_index<2>{iMWarp * WG::kM, 0},
+            a_block_window.get_window_origin() + multi_index<2>{iMWarp * WG::kM, 0},
             make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));

 #if 0 // FIXME: using array will cause register spill
...
@@ -97,9 +97,9 @@ struct BlockGemmASmemBSmemCRegV1
         // construct B-warp-window
         auto b_warp_window_tmp = make_tile_window(
-            b_block_window_tmp.get_bottom_tensor_view(),
+            b_block_window.get_bottom_tensor_view(),
             make_tuple(number<WG::kN>{}, number<WG::kK>{}),
-            b_block_window_tmp.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0},
+            b_block_window.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0},
             make_static_tile_distribution(typename WG::BWarpDstrEncoding{}));

 #if 0 // FIXME: using array will cause register spill
...
@@ -200,12 +200,12 @@ struct BlockGemmASmemBSmemCRegV1
     }

     // C = A * B
-    template <typename ABlockTensorTmp, typename BBlockWindowTmp>
+    template <typename ABlockTensorTmp, typename BBlockWindow>
     CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp,
-                                   const BBlockWindowTmp& b_block_window_tmp) const
+                                   const BBlockWindow& b_block_window) const
     {
         auto c_block_tensor = MakeCBlockTile();
-        operator()(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp);
+        operator()(c_block_tensor, a_block_tensor_tmp, b_block_window);
         return c_block_tensor;
     }
 };
...
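A minimal call-site sketch for the two overloads above (block_gemm, c_block_tensor, a_block_window, a_block_tensor and b_block_window are placeholders; real instances come from the block-gemm policy, MakeCBlockTile and make_tile_window):

    // C += A * B : accumulate into an existing C block tile
    block_gemm(c_block_tensor, a_block_window, b_block_window);
    // C  = A * B : this overload allocates the C block tile and returns it
    auto c_out = block_gemm(a_block_tensor, b_block_window);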
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
View file @ 0475a327
...
@@ -3,12 +3,13 @@
 #pragma once

-#include "ck_tile/core.hpp"
-#include "ck_tile/ops/common.hpp"
+#include <iostream>
+#include <string>
+
+#include "ck_tile/core.hpp"
+#include "ck_tile/ops/common.hpp"
+#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"

 namespace ck_tile {

 template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
...
@@ -17,20 +18,19 @@ struct GemmKernel
     using TilePartitioner  = remove_cvref_t<TilePartitioner_>;
     using GemmPipeline     = remove_cvref_t<GemmPipeline_>;
     using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;

-    static constexpr index_t KernelBlockSize = GemmPipeline::kBlockSize;
-    using ADataType    = remove_cvref_t<typename GemmPipeline::ADataType>;
-    using BDataType    = remove_cvref_t<typename GemmPipeline::BDataType>;
-    using CAccDataType = remove_cvref_t<typename GemmPipeline::CDataType>;
-    using CODataType   = remove_cvref_t<typename EpiloguePipeline::ODataType>;
-    using ALayout      = remove_cvref_t<typename GemmPipeline::ALayout>;
-    using BLayout      = remove_cvref_t<typename GemmPipeline::BLayout>;
-    using CLayout      = remove_cvref_t<typename GemmPipeline::CLayout>;
+    static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
+    using LayoutA = remove_cvref_t<typename GemmPipeline::LayoutA>;
+    using LayoutB = remove_cvref_t<typename GemmPipeline::LayoutB>;
+    using LayoutC = remove_cvref_t<typename GemmPipeline::LayoutC>;
+    using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
+    using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
+    // using CAccDataType = remove_cvref_t<typename GemmPipeline::CDataType>;
+    using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;

-    __host__ static constexpr auto GridSize(index_t M_size, index_t N_size, index_t Batch_size)
+    __host__ static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
     {
-        return TilePartitioner::GridSize(M_size, N_size, Batch_size);
+        return TilePartitioner::GridSize(M, N, KBatch);
     }

     __host__ static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
...
@@ -40,34 +40,30 @@ struct GemmKernel
         const void* a_ptr;
         const void* b_ptr;
         void* c_ptr;
-        float epsilon;
-        ck_tile::index_t M;
-        ck_tile::index_t N;
-        ck_tile::index_t K;
-        ck_tile::index_t stride_A;
-        ck_tile::index_t stride_B;
-        ck_tile::index_t stride_C;
+        index_t M;
+        index_t N;
+        index_t K;
+        index_t stride_A;
+        index_t stride_B;
+        index_t stride_C;
     };

     CK_TILE_HOST static constexpr GemmCommonKargs MakeKargs(const void* a_ptr,
                                                             const void* b_ptr,
                                                             void* c_ptr,
-                                                            float epsilon,
-                                                            ck_tile::index_t M,
-                                                            ck_tile::index_t N,
-                                                            ck_tile::index_t K,
-                                                            ck_tile::index_t stride_A,
-                                                            ck_tile::index_t stride_B,
-                                                            ck_tile::index_t stride_C)
+                                                            index_t M,
+                                                            index_t N,
+                                                            index_t K,
+                                                            index_t stride_A,
+                                                            index_t stride_B,
+                                                            index_t stride_C)
     {
-        return GemmCommonKargs{a_ptr, b_ptr, c_ptr, epsilon, M, N, K, stride_A, stride_B, stride_C};
+        return GemmCommonKargs{a_ptr, b_ptr, c_ptr, M, N, K, stride_A, stride_B, stride_C};
     }

-    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
+    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
     {
-        return ck_tile::max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
+        return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
     }

     CK_TILE_DEVICE void operator()(GemmCommonKargs kargs) const
...
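A host-side sketch of how the pieces above are meant to be wired together (Kernel, the device pointers and the row-major stride choices are assumptions for illustration; only MakeKargs, GridSize and BlockSize mirror the code in this hunk):

    // using Kernel = GemmKernel<TilePartitioner, GemmPipeline, EpiloguePipeline>;
    auto kargs = Kernel::MakeKargs(a_dev, b_dev, c_dev,
                                   M, N, K,
                                   /*stride_A=*/K, /*stride_B=*/K, /*stride_C=*/N);
    const auto grids  = Kernel::GridSize(M, N, /*KBatch=*/1);
    const auto blocks = Kernel::BlockSize();
    // ...then launch with the project's usual kernel-launch helper.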
@@ -78,13 +74,13 @@ struct GemmKernel
         const BDataType* b_start = static_cast<const BDataType*>(kargs.b_ptr);
         // Convert pointers to tensor views
         auto a_tensor_view = [&]() {
-            if constexpr(std::is_same_v<LayoutA, tensor_layout::gemm::ColumnMajor>)
+            if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
             {
                 return make_naive_tensor_view<address_space_enum::global>(
                     a_start,
                     make_tuple(kargs.M, kargs.K),
-                    make_tuple(1, kargs.stride_A),
-                    number<GemmPipeline::AlignmentA>{},
+                    make_tuple(kargs.stride_A, 1),
+                    number<GemmPipeline::VectorSizeA>{},
                     number<1>{});
             }
             else
...
@@ -92,29 +88,29 @@ struct GemmKernel
                 return make_naive_tensor_view<address_space_enum::global>(
                     a_start,
                     make_tuple(kargs.M, kargs.K),
-                    make_tuple(kargs.stride_A, 1),
-                    number<GemmPipeline::AlignmentA>{},
+                    make_tuple(1, kargs.stride_A),
+                    number<1>{},
                     number<1>{});
             }
         }();

         auto b_tensor_view = [&]() {
-            if constexpr(std::is_same_v<LayoutB, tensor_layout::gemm::RowMajor>)
+            if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
             {
                 return make_naive_tensor_view<address_space_enum::global>(
                     b_start,
                     make_tuple(kargs.N, kargs.K),
                     make_tuple(1, kargs.stride_B),
-                    number<GemmPipeline::AlignmentB>{},
+                    number<1>{},
                     number<1>{});
             }
             else
-            {
+            { // Default NK layout
                 return make_naive_tensor_view<address_space_enum::global>(
                     b_start,
                     make_tuple(kargs.N, kargs.K),
                     make_tuple(kargs.stride_B, 1),
-                    number<GemmPipeline::AlignmentB>{},
+                    number<GemmPipeline::VectorSizeB>{},
                     number<1>{});
             }
         }();
...
@@ -122,10 +118,12 @@ struct GemmKernel
         auto a_pad_view = pad_tensor_view(
             a_tensor_view,
             make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
-            sequence<0, GemmPipeline::kPadA ? 1 : 0>{});
+            // somehow clang-format is splitting below line into multiple.
+            // clang-format off
+            sequence<false, GemmPipeline::kPadA>{});
+        // clang-format on

-        auto ABlockWindow = make_tile_window(
+        auto a_block_window = make_tile_window(
             a_pad_view,
             make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
             {i_m, 0});
...
@@ -133,10 +131,11 @@ struct GemmKernel
         auto b_pad_view = pad_tensor_view(
             b_tensor_view,
             make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
-            sequence<0, GemmPipeline::kPadB ? 1 : 0>{});
+            // clang-format off
+            sequence<false, GemmPipeline::kPadB>{});
+        // clang-format on

-        auto BBlockWindow = make_tile_window(
+        auto b_block_window = make_tile_window(
             b_pad_view,
             make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
             {i_n, 0});
...
@@ -144,20 +143,21 @@ struct GemmKernel
         // allocate LDS
         __shared__ char smem_ptr[GetSmemSize()];

-        const index_t num_loop = (kargs.K + TilePartitioner::kK - 1) / TilePartitioner::kK;
-        auto acc = GemmPipeline{}(ABlockWindow, BBlockWindow, num_loop, smem_ptr);
+        const index_t num_loop = TilePartitioner::GetLoopNum(kargs.K);

-        CODataType* c_start = static_cast<CODataType*>(kargs.c_ptr);
+        // Run GEMM cooperatively by whole wokrgroup.
+        auto c_block_tile =
+            GemmPipeline{}.template operator()(a_block_window, b_block_window, num_loop, smem_ptr);
+
+        CDataType* c_start = static_cast<CDataType*>(kargs.c_ptr);
         auto c_tensor_view = [&]() {
-            if constexpr(std::is_same_v<LayoutC, tensor_layout::gemm::ColumnMajor>)
+            if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
             {
                 return make_naive_tensor_view<address_space_enum::global>(
                     c_start,
                     make_tuple(kargs.M, kargs.N),
-                    make_tuple(1, kargs.stride_C),
-                    number<GemmPipeline::AlignmentC>{},
+                    make_tuple(kargs.stride_C, 1),
+                    number<GemmPipeline::VectorSizeC>{},
                     number<1>{});
             }
             else
...
@@ -165,8 +165,8 @@ struct GemmKernel
                 return make_naive_tensor_view<address_space_enum::global>(
                     c_start,
                     make_tuple(kargs.M, kargs.N),
-                    make_tuple(kargs.stride_C, 1),
-                    number<GemmPipeline::AlignmentC>{},
+                    make_tuple(1, kargs.stride_C),
+                    number<1>{},
                     number<1>{});
             }
         }();
...
@@ -174,14 +174,15 @@ struct GemmKernel
         auto c_pad_view = pad_tensor_view(
             c_tensor_view,
             make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
-            sequence<0, GemmPipeline::kPadC ? 1 : 0>{});
-        auto CBlockWindow_pad = make_tile_window(
+            // clang-format off
+            sequence<false, GemmPipeline::kPadC>{});
+        // clang-format on
+        auto c_block_window = make_tile_window(
             c_pad_view,
             make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
             {i_m, i_n});

-        EpiloguePipeline{}(CBlockWindow_pad, acc);
+        EpiloguePipeline{}(c_block_window, c_block_tile);
     }
 };
...
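Taken together, the rewritten operator() body follows this shape (a condensed paraphrase of the hunks above, not additional code from the commit):

    // 1. wrap a_ptr/b_ptr in naive tensor views chosen by the A/B layouts;
    // 2. pad each view to the block tile (kPadA/kPadB) and open a_block_window at
    //    (i_m, 0) and b_block_window at (i_n, 0);
    // 3. num_loop = TilePartitioner::GetLoopNum(kargs.K), then
    //    c_block_tile = GemmPipeline{}(a_block_window, b_block_window, num_loop, smem_ptr);
    // 4. build the padded C view/window the same way and finish with
    //    EpiloguePipeline{}(c_block_window, c_block_tile);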