gaoqiong / composable_kernel_ROCM · Commits

Commit 3c305fee, authored Oct 30, 2024 by carlushuang

    support fused dynamic-quant

parent 7fb9b2b6

Changes: 22
Showing 20 changed files with 206 additions and 45 deletions (+206, -45)
  example/ck_tile/02_layernorm2d/generate.py                                      +1   -1
  example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp                              +7   -5
  include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp                                   +1   -0
  include/ck_tile/ops/common.hpp                                                  +1   -0
  include/ck_tile/ops/common/generic_2d_block_shape.hpp                           +77  -0
  include/ck_tile/ops/elementwise.hpp                                             +1   -0
  include/ck_tile/ops/epilogue.hpp                                                +1   -0
  include/ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp                         +74  -22
  include/ck_tile/ops/fmha.hpp                                                    +1   -0
  include/ck_tile/ops/gemm.hpp                                                    +1   -0
  include/ck_tile/ops/image_to_column.hpp                                         +1   -0
  include/ck_tile/ops/layernorm2d.hpp                                             +1   -0
  include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp  +1   -1
  include/ck_tile/ops/permute.hpp                                                 +1   -0
  include/ck_tile/ops/reduce.hpp                                                  +1   -0
  include/ck_tile/ops/reduce/block/block_reduce.hpp                               +13  -10
  include/ck_tile/ops/reduce/block/block_reduce2d.hpp                             +20  -6
  include/ck_tile/ops/rmsnorm2d.hpp                                               +1   -0
  include/ck_tile/ops/softmax.hpp                                                 +1   -0
  include/ck_tile/ops/topk.hpp                                                    +1   -0
example/ck_tile/02_layernorm2d/generate.py

@@ -195,7 +195,7 @@ float layernorm2d_fwd_(const S& s, A a)
     using Default2DEpilogueProblem = ck_tile::Default2DEpilogueProblem<ComputeDataType, YDataType, false, Traits_::kPadN, false>;
     using Default2DEpilogue = ck_tile::Default2DEpilogue<Default2DEpilogueProblem>;
-    using DynamicQuantEpilogueProblem = ck_tile::DynamicQuantEpilogueProblem<ComputeDataType, YScaleDataType, YDataType, ck_tile::DynamicQuantEpilogueTraits<false, Traits_::kPadN, false, true/*max3*/>>;
+    using DynamicQuantEpilogueProblem = ck_tile::DynamicQuantEpilogueProblem<ComputeDataType, YScaleDataType, YDataType, typename Traits_::Shape, ck_tile::DynamicQuantEpilogueTraits<false, Traits_::kPadN, false, true/*max3*/>>;
     using DynamicQuantEpilogue = ck_tile::DynamicQuantEpilogue<DynamicQuantEpilogueProblem>;
example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp

@@ -203,18 +203,20 @@ bool run(const ck_tile::ArgParser& arg_parser)
     if(fused_sweep == 1)
     {
-        auto dquant_functor = [&](int m_, auto o_, auto acc_) {
+        auto dquant_functor = [&](int m_, auto& o_, const auto& acc_) {
             int N_ = acc_.mDesc.get_lengths()[1];
-            ComputeDataType absmax = 0;
+            ComputeDataType absmax = static_cast<ComputeDataType>(0);
             for(int n_ = 0; n_ < N_; n_++)
             {
-                const auto a = abs(acc_(m_, n_));
+                const auto a = ck_tile::abs(acc_(m_, n_));
                 absmax       = a > absmax ? a : absmax;
             }
-            y_scale_host_ref(m_) = absmax / 127.0;
+            // printf("cpu:absmax:%f\n", absmax);
+            ComputeDataType y_scale = absmax / static_cast<ComputeDataType>(127.0);
+            y_scale_host_ref(m_)    = ck_tile::type_convert<ScaleDataType>(y_scale);
             for(int n_ = 0; n_ < N_; n_++)
             {
-                o_(m_, n_) = static_cast<YDataType>(acc_(m_, n_) / y_scale_host_ref(m_));
+                o_(m_, n_) = ck_tile::type_convert<YDataType>(acc_(m_, n_) / y_scale);
             }
         };
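Note on the host reference above: per output row it takes the absolute maximum of the accumulator, derives y_scale = absmax / 127, records the scale, and emits each element divided by that scale. Below is a minimal standalone sketch of the same math; the plain std:: containers, the int8_t output type, and the sample data are illustrative assumptions, not the example's actual HostTensor code.

// Standalone sketch of the per-row dynamic-quant reference (illustrative types).
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

int main()
{
    const int M = 2, N = 4;
    std::vector<float> acc = {0.5f, -3.0f, 1.25f, 2.0f,    // row 0
                              -0.1f, 0.2f, -0.05f, 0.15f}; // row 1
    std::vector<float> y_scale(M);
    std::vector<std::int8_t> o(M * N);

    for(int m = 0; m < M; ++m)
    {
        float absmax = 0.f;
        for(int n = 0; n < N; ++n)
            absmax = std::max(absmax, std::fabs(acc[m * N + n])); // row absmax

        y_scale[m] = absmax / 127.f; // per-row dynamic scale
        for(int n = 0; n < N; ++n)   // quantize: divide by the row scale, narrow to int8
            o[m * N + n] = static_cast<std::int8_t>(acc[m * N + n] / y_scale[m]);

        std::printf("row %d: absmax=%f scale=%f\n", m, absmax, y_scale[m]);
    }
    return 0;
}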
include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp

@@ -9,4 +9,5 @@
 #include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp"
 #include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp"
 #include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp"
+#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/common.hpp

@@ -3,4 +3,5 @@
 #pragma once

+#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/common/generic_2d_block_shape.hpp (new file, 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

namespace ck_tile {

/*
// clang-format off
    4-level descriptor: BlockTile-> WarpPerBlock-> WarpTile-> Vector

                      Block_N (Warp_N * WarpPerBlock_N * Repeat_N)
        +<----------------------< Repeat_N(2) >--------------------->+
        |                                                            |
        +<--- <WarpPerBlock_N(2)> --->+
              Warp_N
        +--------------+--------------+--------------+--------------+----+----------------+
 Warp_M |    wrap_0    |    wrap_1    |              |              |  ^          ^
        +--------------+--------------+              |              | <WarpPerBlock_M(2)> |
        |    wrap_2    |    wrap_3    |              |              |  v          |
        +--------------+--------------+--------------+--------------+----+     Block_M
        |              |              |              |              |           |
        +              +              |              |              |           |
        |              |              |              |              |           v
        +--------------+--------------+--------------+--------------+           +

    each Warp-tile (e.g 16 thrd per row)
              Vector_N (contiguous pixels each thrd holds along N, or vector size)
        +-----------+-----------+-----------+-----------+-----------+
        |  thrd_0   |  thrd_1   |  thrd_2   |  thrd_3   |  ...         Vector_M
        +-----------+-----------+-----------+-----------+-----------+
        |  thrd_16  |  thrd_17  |  thrd_18  |  thrd_19  |  ...
        +-----------+-----------+-----------+-----------+-----------+
// clang-format on
*/
template <typename BlockTile_,    // block size, seq<M, N>
          typename WarpPerBlock_, // num warps along seq<M, N>
          typename WarpTile_,     // warp size, seq<M, N>
          typename Vector_,       // contiguous pixels(vector size) along seq<M, N>
          index_t BlockSize_ =
              warpSize * reduce_on_sequence(WarpPerBlock_{}, multiplies{}, number<1>{})>
struct Generic2dBlockShape
{
    // block size
    static constexpr index_t Block_M = BlockTile_::at(number<0>{});
    static constexpr index_t Block_N = BlockTile_::at(number<1>{});

    // num warps along seq<M, N>, within each block
    static constexpr index_t WarpPerBlock_M = WarpPerBlock_::at(number<0>{});
    static constexpr index_t WarpPerBlock_N = WarpPerBlock_::at(number<1>{});

    // warp size
    static constexpr index_t Warp_M = WarpTile_::at(number<0>{});
    static constexpr index_t Warp_N = WarpTile_::at(number<1>{});

    static_assert(Block_M % (WarpPerBlock_M * Warp_M) == 0);
    static_assert(Block_N % (WarpPerBlock_N * Warp_N) == 0);

    // repeat of each thread along seq<M, N>
    static constexpr index_t Repeat_M = Block_M / (WarpPerBlock_M * Warp_M);
    static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N);

    // vector size along seq<M, N>
    static constexpr index_t Vector_M = Vector_::at(number<0>{});
    static constexpr index_t Vector_N = Vector_::at(number<1>{});

    static_assert(Warp_M % Vector_M == 0);
    static_assert(Warp_N % Vector_N == 0);

    // num of threads along seq<M, N>, within each warp
    static constexpr index_t ThreadPerWarp_M = Warp_M / Vector_M;
    static constexpr index_t ThreadPerWarp_N = Warp_N / Vector_N;

    static constexpr index_t BlockSize = BlockSize_;
};

} // namespace ck_tile
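As a sanity check on the new header, the derived quantities can be reproduced with plain integers. The sketch below redoes Generic2dBlockShape's arithmetic for one illustrative configuration; the concrete tile sizes, warp size of 64, and the plain-int representation are assumptions for the example, not values taken from this commit.

// Plain-integer sketch of Generic2dBlockShape's derivation (illustrative numbers).
// BlockTile = <8, 256>, WarpPerBlock = <2, 2>, WarpTile = <4, 64>, Vector = <1, 4>.
#include <cstdio>

int main()
{
    constexpr int warp_size = 64; // assumed wavefront size

    constexpr int Block_M = 8, Block_N = 256;             // block tile
    constexpr int WarpPerBlock_M = 2, WarpPerBlock_N = 2; // warps per block
    constexpr int Warp_M = 4, Warp_N = 64;                // warp tile
    constexpr int Vector_M = 1, Vector_N = 4;             // per-thread vector

    static_assert(Block_M % (WarpPerBlock_M * Warp_M) == 0);
    static_assert(Block_N % (WarpPerBlock_N * Warp_N) == 0);
    static_assert(Warp_M % Vector_M == 0 && Warp_N % Vector_N == 0);

    constexpr int Repeat_M = Block_M / (WarpPerBlock_M * Warp_M);          // 1
    constexpr int Repeat_N = Block_N / (WarpPerBlock_N * Warp_N);          // 2
    constexpr int ThreadPerWarp_M = Warp_M / Vector_M;                     // 4
    constexpr int ThreadPerWarp_N = Warp_N / Vector_N;                     // 16
    constexpr int BlockSize = warp_size * WarpPerBlock_M * WarpPerBlock_N; // 256

    std::printf("Repeat: %dx%d, ThreadPerWarp: %dx%d, BlockSize: %d\n",
                Repeat_M, Repeat_N, ThreadPerWarp_M, ThreadPerWarp_N, BlockSize);
    return 0;
}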
include/ck_tile/ops/elementwise.hpp

@@ -4,4 +4,5 @@
 #pragma once

 #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
+#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/epilogue.hpp

@@ -6,4 +6,5 @@
 #include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"
 #include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
 #include "ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp"
+#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp

@@ -4,7 +4,7 @@
 #pragma once

 #include "ck_tile/core.hpp"
-#include "ck_tile/ops/reduce/block/block_reduce.hpp"
+#include "ck_tile/ops/reduce.hpp"

 namespace ck_tile {

@@ -18,12 +18,17 @@ struct DynamicQuantEpilogueTraits
 };

 // this epilogue just store out a M*N matrix, row major
-template <typename AccDataType_, typename YScaleDataType_, typename ODataType_, typename Traits_>
+template <typename AccDataType_,
+          typename YScaleDataType_,
+          typename ODataType_,
+          typename BlockShape_,
+          typename Traits_>
 struct DynamicQuantEpilogueProblem
 {
     using AccDataType    = remove_cvref_t<AccDataType_>;
     using YScaleDataType = remove_cvref_t<YScaleDataType_>;
     using ODataType      = remove_cvref_t<ODataType_>;
+    using BlockShape     = remove_cvref_t<BlockShape_>; // can consum generic 2d shape
     using Traits         = remove_cvref_t<Traits_>;
 };

@@ -34,42 +39,81 @@ struct DynamicQuantEpilogue
     using AccDataType    = remove_cvref_t<typename Problem::AccDataType>;
     using YScaleDataType = remove_cvref_t<typename Problem::YScaleDataType>;
     using ODataType      = remove_cvref_t<typename Problem::ODataType>;
+    using BlockShape     = remove_cvref_t<typename Problem::BlockShape>;

     static constexpr bool kPadM       = Problem::Traits::kPadM;
     static constexpr bool kPadN       = Problem::Traits::kPadN;
     static constexpr bool UseRawStore = Problem::Traits::UseRawStore;
     static constexpr bool UseMax3     = Problem::Traits::UseMax3;

-    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; }
+    CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2d()
+    {
+        using P_ = BlockReduce2dProblem<AccDataType, AccDataType, BlockShape>;
+        return BlockReduce2d<P_>{};
+    }
+
+    CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dSync()
+    {
+        using P_ = BlockReduce2dProblem<AccDataType, AccDataType, BlockShape>;
+        return BlockReduce2dSync<P_>{};
+    }
+
+    CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dCrossWarpSync()
+    {
+        using P_ = BlockReduce2dProblem<AccDataType, AccDataType, BlockShape>;
+        return BlockReduce2dCrossWarpSync<P_>{};
+    }
+
+    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
+    {
+        auto reduce_crosswarp_sync = GetBlockReduce2dCrossWarpSync();
+        return reduce_crosswarp_sync.GetSmemSize();
+    }

     // TODO: this function assume store out vector size is the same as OAccTile last dimension size
     // how do we fix this ?
     template <typename ODramWindowTmp, typename YScaleWindow, typename OAccTile>
     CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp,
                                    YScaleWindow& y_scale_window,
-                                   const OAccTile& o_acc_tile)
+                                   const OAccTile& o_acc_tile,
+                                   void* smem)
     {
-        // compute row max
-        auto reduce_row_absmax = BlockReduce2D{o_acc_tile, type_convert<AccDataType>(0)};
+        auto reduce                = GetBlockReduce2d();
+        auto reduce_sync           = GetBlockReduce2dSync();
+        auto reduce_crosswarp_sync = GetBlockReduce2dCrossWarpSync();
+
+        const auto f_absmax = [](auto acc_, auto v_0_) { return max(acc_, abs(v_0_)); };

         auto row_absmax = [&]() {
-            if constexpr(UseMax3 && std::is_same_v<AccDataType, float>)
+            constexpr auto y_size_per_row =
+                OAccTile{}.get_tile_distribution().get_ys_to_d_descriptor().get_lengths().at(
+                    number<1>{});
+            // constexpr auto y_size_per_row = OAccTile::get_lengths()[number<1>{}];
+            if constexpr(UseMax3 && std::is_same_v<AccDataType, float> && y_size_per_row % 2 == 0)
             {
-                const auto f_max = [](auto acc_, auto v_0_) { return max(acc_, abs(v_0_)); };
-                // const auto f_max3 = [](auto acc_, auto v_0_, auto v_1_) {
-                //     float rtn;
-                //     asm volatile("v_max3_f32 %0, %1, abs(%2), abs(%3)"
-                //                  : "=v"(rtn)
-                //                  : "v"(acc_), "v"(v_0_), "v"(v_1_));
-                //     return rtn;
-                // };
-                // return reduce_row_absmax(f_max3, f_max, sequence<1, 2>{});
-                return reduce_row_absmax(f_max);
+                // fast max3 implementation
+                const auto f_max3 = [](auto acc_, auto v_0_, auto v_1_) {
+                    float rtn;
+                    asm volatile("v_max3_f32 %0, %1, abs(%2), abs(%3)"
+                                 : "=v"(rtn)
+                                 : "v"(acc_), "v"(v_0_), "v"(v_1_));
+                    return rtn;
+                };
+                return reduce(o_acc_tile, type_convert<AccDataType>(0), f_max3, sequence<1, 2>{});
             }
             else
             {
-                const auto f_max = [](auto acc_, auto v_0_) { return max(acc_, abs(v_0_)); };
-                return reduce_row_absmax(f_max);
+                return reduce(o_acc_tile, type_convert<AccDataType>(0), f_absmax);
            }
        }();

+        reduce_sync(row_absmax, f_absmax);
+        reduce_crosswarp_sync(row_absmax, smem, f_absmax);
+
+#if 0
+        sweep_tile(row_absmax, [&](auto idx) {
+            auto ddd = row_absmax[idx];
+            printf("tid:%d, absmax:%f\n", static_cast<int>(threadIdx.x), ddd);
+        });
+#endif
         // here y_scale is Acc TYpe, need convert to YScale type later
         auto y_scale = tile_elementwise_in(

@@ -80,15 +124,23 @@ struct DynamicQuantEpilogue

         store_tile(y_scale_window, cast_tile<YScaleDataType>(y_scale));

+        auto o_acc_scaled_tile =
+            make_static_distributed_tensor<AccDataType>(o_acc_tile.get_tile_distribution());
+
+        sweep_tile(o_acc_tile, [&](auto idx) {
+            constexpr auto row_id  = make_tuple(idx[number<0>{}]);
+            o_acc_scaled_tile(idx) = o_acc_tile[idx] / y_scale(row_id);
+        });
+
         // TODO: this is ugly
         if constexpr(UseRawStore && (kPadM || kPadN))
         {
-            store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
+            store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_scaled_tile));
             buffer_store_fence();
         }
         else
         {
-            store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
+            store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_scaled_tile));
         }
     }
 };
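Summary of the rewritten epilogue: each thread reduces its slice of the accumulator tile to a per-row absolute maximum (optionally two elements at a time through v_max3_f32), the partial results are combined first within a warp and then across warps through the LDS buffer now passed in as smem, y_scale = absmax / 127 is stored per row, and the output tile is divided by that scale before being cast to ODataType. The host-side simulation below mirrors only the reduction structure, with one partial per "warp" and a shared buffer standing in for LDS; warp and element counts are illustrative assumptions, not the kernel's launch shape.

// Host simulation of the partial-then-cross-warp absmax reduction (illustrative).
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

int main()
{
    constexpr int num_warps = 4, elems_per_warp = 8;
    std::vector<float> row(num_warps * elems_per_warp);
    for(std::size_t i = 0; i < row.size(); ++i)
        row[i] = (i % 7) - 3.5f; // arbitrary data

    const auto f_absmax = [](float acc, float v) { return std::max(acc, std::fabs(v)); };

    // per-warp partial reduction (roughly what BlockReduce2d + BlockReduce2dSync yield)
    std::vector<float> smem(num_warps); // stand-in for the LDS scratch the epilogue now takes
    for(int w = 0; w < num_warps; ++w)
    {
        float partial = 0.f;
        for(int e = 0; e < elems_per_warp; ++e)
            partial = f_absmax(partial, row[w * elems_per_warp + e]);
        smem[w] = partial;
    }

    // cross-warp combine (roughly what BlockReduce2dCrossWarpSync does through smem)
    float row_absmax = 0.f;
    for(int w = 0; w < num_warps; ++w)
        row_absmax = f_absmax(row_absmax, smem[w]);

    std::printf("row absmax = %f, scale = %f\n", row_absmax, row_absmax / 127.f);
    return 0;
}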
include/ck_tile/ops/fmha.hpp

@@ -43,4 +43,5 @@
 #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
 #include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp"
 #include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp"
+#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/gemm.hpp

@@ -37,4 +37,5 @@
 #include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp"
 #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
 #include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp"
+#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/image_to_column.hpp

@@ -6,4 +6,5 @@
 #include "ck_tile/ops/image_to_column/kernel/image_to_column_kernel.hpp"
 #include "ck_tile/ops/image_to_column/pipeline/block_image_to_column_problem.hpp"
 #include "ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp"
+#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/layernorm2d.hpp

@@ -10,4 +10,5 @@
 #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_problem.hpp"
 #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp"
 #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp"
+#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp

@@ -147,7 +147,7 @@ struct Layernorm2dFwdPipelineOnePass
         if constexpr(kFusedSweep == Layernorm2dFusedSweepEnum::DYNAMIC_QUANT)
         {
-            Epilogue{}(y_window_, y_scale_window, ln);
+            Epilogue{}(y_window_, y_scale_window, ln, smem);
         }
         else
             Epilogue{}(y_window_, ln);
include/ck_tile/ops/permute.hpp

@@ -5,4 +5,5 @@
 #include "ck_tile/ops/permute/kernel/generic_permute_kernel.hpp"
 #include "ck_tile/ops/permute/pipeline/generic_petmute_problem.hpp"
+#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/reduce.hpp

@@ -7,4 +7,5 @@
 #include "ck_tile/ops/reduce/block/block_reduce2d.hpp"
 #include "ck_tile/ops/reduce/block/block_reduce2d_default_policy.hpp"
 #include "ck_tile/ops/reduce/block/block_reduce2d_problem.hpp"
+#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/reduce/block/block_reduce.hpp

@@ -301,7 +301,10 @@ struct BlockReduce2D
                                            .get_static_tile_distribution_encoding(),
                                        ReduceDim{}));

-        return make_static_distributed_tensor<InDataType>(acc_dstr);
+        auto dst_ = make_static_distributed_tensor<InDataType>(acc_dstr);
+        // init acc_tensor
+        tile_elementwise_inout([&](auto& x_) { x_ = type_convert<InDataType>(reduce_init); }, dst_);
+        return dst_;
     }

     // return number of pixels each lane need to reduce
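The MakeYBlockTile change above means the accumulator now comes back already filled with the reduction's init value, so callers can fold into it directly. A trivial scalar analogue of the added tile_elementwise_inout initialization follows; the plain std:: containers and the sample data are an illustrative stand-in for the distributed tensor.

// Scalar analogue: create the accumulator, fill with reduce_init, then fold into it.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

int main()
{
    const float reduce_init = 0.f; // identity for an absmax-style reduction
    std::vector<float> dst(2);     // freshly made accumulator ("dst_")

    // counterpart of tile_elementwise_inout([&](auto& x_){ x_ = reduce_init; }, dst_);
    std::for_each(dst.begin(), dst.end(), [&](float& x) { x = reduce_init; });

    // folding into the pre-initialized accumulator is now well defined
    const std::vector<float> row0 = {0.5f, -3.0f}, row1 = {1.25f, -0.25f};
    for(std::size_t i = 0; i < row0.size(); ++i)
    {
        dst[0] = std::max(dst[0], std::fabs(row0[i]));
        dst[1] = std::max(dst[1], std::fabs(row1[i]));
    }

    std::printf("dst = {%f, %f}\n", dst[0], dst[1]);
    return 0;
}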
include/ck_tile/ops/reduce/block/block_reduce2d.hpp

@@ -17,14 +17,24 @@ struct BlockReduce2d
     CK_TILE_DEVICE constexpr BlockReduce2d() {}

-    template <typename XDistributedTensor_, typename YDistributedTensor_, typename ReduceFunc>
+    template <typename XDistributedTensor_,
+              typename YDistributedTensor_,
+              typename ReduceFunc,
+              typename ReducePacksPerXDim = uniform_sequence_gen_t<2, 1>>
     CK_TILE_DEVICE void operator()(const XDistributedTensor_& x_tensor,
                                    YDistributedTensor_& y_tensor,
-                                   const ReduceFunc& reduce_func)
+                                   const ReduceFunc& reduce_func,
+                                   ReducePacksPerXDim = {})
     {
+        sweep_tile<XDistributedTensor_>(
+            [&](auto... idx_) {
+                constexpr auto idx_0 = make_tuple(make_tuple(idx_[number<0>{}]...)[number<0>{}]);
+                y_tensor(idx_0)      = reduce_func(y_tensor(idx_0), x_tensor[idx_]...);
+            },
+            ReducePacksPerXDim{});
+#if 0
         constexpr auto I0 = number<0>{};
         constexpr auto I1 = number<1>{};

         constexpr auto spans = XDistributedTensor_::get_distributed_spans();

         // FIXME: hard coded to reduce 2nd axis

@@ -42,6 +52,7 @@ struct BlockReduce2d
             y_tensor(y_dstr_idx) = y;
         });
+#endif
     }

     template <typename XDistributedTensor_>

@@ -63,14 +74,17 @@ struct BlockReduce2d
         return tensor;
     }

-    template <typename XDistributedTensor_, typename ReduceFunc>
+    template <typename XDistributedTensor_,
+              typename ReduceFunc,
+              typename ReducePacksPerXDim = uniform_sequence_gen_t<2, 1>>
     CK_TILE_DEVICE auto operator()(const XDistributedTensor_& x_tensor,
                                    const ComputeDataType& reduce_init,
-                                   const ReduceFunc& reduce_func)
+                                   const ReduceFunc& reduce_func,
+                                   ReducePacksPerXDim = {})
     {
         auto y_tensor = MakeYBlockTile<XDistributedTensor_>();
         set_tile(y_tensor, reduce_init);
-        (*this)(x_tensor, y_tensor, reduce_func);
+        (*this)(x_tensor, y_tensor, reduce_func, ReducePacksPerXDim{});

         return y_tensor;
     }
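The new ReducePacksPerXDim parameter lets the caller ask sweep_tile to hand the reduce functor more than one x element per call; the dynamic-quant epilogue uses a pack of two along the reduced axis so its 3-operand f_max3 can consume pairs. A scalar sketch of the same idea is below; PACK, reduce_row, and the functors are illustrative stand-ins for the ck_tile machinery, not its actual API.

// Scalar sketch of packed reduction: feed PACK consecutive elements per functor call.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

template <int PACK, typename F>
float reduce_row(const std::vector<float>& row, float init, F&& f)
{
    float acc = init;
    for(std::size_t n = 0; n + PACK <= row.size(); n += PACK)
    {
        if constexpr(PACK == 1)
            acc = f(acc, row[n]);           // one x element per call
        else
            acc = f(acc, row[n], row[n + 1]); // two x elements per call
    }
    return acc;
}

int main()
{
    std::vector<float> row = {0.5f, -3.0f, 1.25f, 2.0f};

    // pack = 1: plain absmax, analogous to the default f_absmax path
    float r1 = reduce_row<1>(row, 0.f,
                             [](float acc, float v) { return std::max(acc, std::fabs(v)); });

    // pack = 2: 3-operand step, the scalar analogue of the f_max3 fast path
    float r2 = reduce_row<2>(row, 0.f, [](float acc, float a, float b) {
        return std::max(acc, std::max(std::fabs(a), std::fabs(b)));
    });

    std::printf("pack1=%f pack2=%f\n", r1, r2);
    return 0;
}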
include/ck_tile/ops/rmsnorm2d.hpp

@@ -9,4 +9,5 @@
 #include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp"
 #include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_problem.hpp"
 #include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp"
+#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/softmax.hpp

@@ -5,4 +5,5 @@
 #include "ck_tile/ops/softmax/block/block_softmax_2d.hpp"
 #include "ck_tile/ops/softmax/block/block_softmax_2d_problem.hpp"
+#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/topk.hpp

@@ -5,4 +5,5 @@
 #include "ck_tile/ops/topk/block/block_topk_stream_2d.hpp"
 #include "ck_tile/ops/topk/block/block_topk_stream_2d_problem.hpp"
+#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"