Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
17ed368f
Unverified
Commit
17ed368f
authored
Jun 17, 2024
by
carlushuang
Committed by
GitHub
Jun 17, 2024
Browse files
[CK_TILE][FA] using pk f16_f32 (#1343)
* [CK_TILE][FA] using pk f16_f32 * correct a error
parent
e0210316
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
60 additions
and
8 deletions
+60
-8
include/ck_tile/core/arch/arch.hpp
include/ck_tile/core/arch/arch.hpp
+7
-4
include/ck_tile/core/config.hpp
include/ck_tile/core/config.hpp
+4
-0
include/ck_tile/core/tensor/tile_elementwise.hpp
include/ck_tile/core/tensor/tile_elementwise.hpp
+41
-2
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
.../ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
+8
-2
No files found.
include/ck_tile/core/arch/arch.hpp
View file @
17ed368f
...
...
@@ -61,10 +61,13 @@ CK_TILE_DEVICE index_t get_block_id() { return blockIdx.x; }
CK_TILE_DEVICE
void
block_sync_lds
()
{
#if CK_TILE_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
asm
volatile
(
"\
s_waitcnt lgkmcnt(0)
\n
\
s_barrier \
"
::
);
// asm volatile("\
// s_waitcnt lgkmcnt(0) \n \
// s_barrier \
// " ::);
__builtin_amdgcn_s_waitcnt
(
0xc07f
);
__builtin_amdgcn_s_barrier
();
#else
__syncthreads
();
#endif
...
...
include/ck_tile/core/config.hpp
View file @
17ed368f
...
...
@@ -167,6 +167,10 @@
#define CK_TILE_USE_SUBDWORD_TILE_CAST 0
#endif
#ifndef CK_TILE_USE_PK_FP16_TILE_CAST
#define CK_TILE_USE_PK_FP16_TILE_CAST 0
#endif
// TODO: better solve this inside compiler
#ifndef CK_TILE_FMHA_FWD_FAST_EXP2
#define CK_TILE_FMHA_FWD_FAST_EXP2 0
...
...
include/ck_tile/core/tensor/tile_elementwise.hpp
View file @
17ed368f
...
...
@@ -110,7 +110,7 @@ CK_TILE_DEVICE void clear_tile(DstrTensors& dstr_tensor)
namespace
impl
{
// TODO: this is ugly
template
<
typename
OutDataType
,
typename
InTensor
>
CK_TILE_DEVICE
auto
cast_tile_pk_fp8
x4
(
const
InTensor
&
in_dstr_tensors
)
CK_TILE_DEVICE
auto
cast_tile_pk_fp8
_fp32
(
const
InTensor
&
in_dstr_tensors
)
{
#if defined(__gfx94__)
// This API is designed to use the _pk_ serious of function
...
...
@@ -156,6 +156,37 @@ CK_TILE_DEVICE auto cast_tile_pk_fp8x4(const InTensor& in_dstr_tensors)
#endif
}
template
<
typename
OutDataType
,
typename
InTensor
>
CK_TILE_DEVICE
auto
cast_tile_pk_fp16_fp32
(
const
InTensor
&
in_dstr_tensors
)
{
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__)
// This API is designed to use the _pk_ serious of function
constexpr
auto
in_tile_dstr
=
InTensor
::
get_tile_distribution
();
constexpr
index_t
thread_buffer_size
=
InTensor
::
get_thread_buffer_size
();
static_assert
(
thread_buffer_size
%
2
==
0
);
constexpr
index_t
thread_buffer_size_pk
=
thread_buffer_size
/
2
;
auto
out_dstr_tensor
=
make_static_distributed_tensor
<
OutDataType
>
(
in_tile_dstr
);
// TODO: this is rtz cvt, need be very careful
for
(
index_t
i
=
0
;
i
<
thread_buffer_size_pk
;
i
++
)
{
auto
o
=
__builtin_amdgcn_cvt_pkrtz
(
in_dstr_tensors
.
get_thread_buffer
()[
2
*
i
+
0
],
in_dstr_tensors
.
get_thread_buffer
()[
2
*
i
+
1
]);
out_dstr_tensor
.
get_thread_buffer
().
at
(
2
*
i
+
0
)
=
o
.
x
;
out_dstr_tensor
.
get_thread_buffer
().
at
(
2
*
i
+
1
)
=
o
.
y
;
}
return
out_dstr_tensor
;
#else
// fallback
return
tile_elementwise_in
(
type_convert
<
OutDataType
,
typename
InTensor
::
DataType
>
,
in_dstr_tensors
);
#endif
}
#if CK_TILE_USE_SUBDWORD_TILE_CAST
// this function assume either src or dst (or both) date type is under 1 dword
// we pack subdword value into 1 dword to avoid compiler's default subdword behavior(which is buggy)
...
...
@@ -229,8 +260,16 @@ CK_TILE_DEVICE auto cast_tile(const SrcTensor& src_tensor)
float
>
&&
(
SrcTensor
::
get_thread_buffer_size
()
%
4
==
0
))
{
return
impl
::
cast_tile_pk_fp8
x4
<
DstType
,
SrcTensor
>
(
src_tensor
);
return
impl
::
cast_tile_pk_fp8
_fp32
<
DstType
,
SrcTensor
>
(
src_tensor
);
}
#if CK_TILE_USE_PK_FP16_TILE_CAST
else
if
constexpr
(
std
::
is_same_v
<
DstType
,
fp16_t
>
&&
std
::
is_same_v
<
typename
SrcTensor
::
DataType
,
float
>
&&
(
SrcTensor
::
get_thread_buffer_size
()
%
2
==
0
))
{
return
impl
::
cast_tile_pk_fp16_fp32
<
DstType
,
SrcTensor
>
(
src_tensor
);
}
#endif
#if CK_TILE_USE_SUBDWORD_TILE_CAST
else
if
constexpr
(
sizeof
(
DstType
)
<
4
||
sizeof
(
typename
SrcTensor
::
DataType
)
<
4
)
{
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
View file @
17ed368f
...
...
@@ -578,8 +578,14 @@ struct BlockFmhaPipelineQRKSVSAsync
randval_dram_window
);
}
const
auto
p
=
cast_tile
<
PDataType
>
(
tile_elementwise_in
(
p_compute_element_func
,
p_compute
));
const
auto
p
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
PDataType
,
fp16_t
>
)
return
impl
::
cast_tile_pk_fp16_fp32
<
PDataType
>
(
tile_elementwise_in
(
p_compute_element_func
,
p_compute
));
else
return
cast_tile
<
PDataType
>
(
tile_elementwise_in
(
p_compute_element_func
,
p_compute
));
}();
// STAGE 3, KV gemm
if
constexpr
(
k1_loops
>
1
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment