Repository: gaoqiong/composable_kernel_ROCM

Commit e0594d08 (unverified)
Authored Nov 06, 2024 by Illia Silin; committed by GitHub on Nov 06, 2024
Parents: 7d50244e, 667cd6ab

    Merge pull request #214 from ROCm/merge_from_public

    Merge from public

Changes: 121 files in the full commit; this page shows 20 changed files with 678 additions and 148 deletions (+678 −148).
Changed files:

  include/ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp                               +66  -18
  include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp                     +2   -2
  include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp  +6   -5
  include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp        +10  -23
  include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_problem.hpp         +1   -1
  include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp        +11  -10
  include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp                   +1   -1
  include/ck_tile/ops/reduce/block/block_reduce2d.hpp                                   +2   -1
  include/ck_tile/ops/rmsnorm2d.hpp                                                     +0   -1
  include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp                         +6   -6
  include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_shape.hpp (deleted)                +0   -78
  include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp      +1   -0
  include/ck_tile/ops/smoothquant.hpp (new)                                             +12  -0
  include/ck_tile/ops/smoothquant/kernel/smoothquant_kernel.hpp (new)                   +176 -0
  include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_default_policy.hpp (new) +95 -0
  include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_one_pass.hpp (new)      +94  -0
  include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_problem.hpp (new)       +35  -0
  include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_two_pass.hpp (new)      +132 -0
  include/ck_tile/remod.py                                                              +3   -2
  library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp  +25 -0
include/ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp

@@ -8,17 +8,23 @@
 namespace ck_tile {

-template <bool kPadM_, bool kPadN_, bool UseRawStore_ = true, bool UseMax3_ = false>
+template <bool kPadM_,
+          bool kPadN_,
+          bool UseSmoothInputScale_,
+          bool UseRawStore_ = true,
+          bool UseMax3_     = false>
 struct DynamicQuantEpilogueTraits
 {
     static constexpr bool kPadM = kPadM_;
     static constexpr bool kPadN = kPadN_;
-    static constexpr bool UseRawStore = UseRawStore_;
-    static constexpr bool UseMax3     = UseMax3_;
+    static constexpr bool UseSmoothInputScale = UseSmoothInputScale_;
+    static constexpr bool UseRawStore         = UseRawStore_;
+    static constexpr bool UseMax3             = UseMax3_;
 };

 // this epilogue just store out a M*N matrix, row major
 template <typename AccDataType_,
+          typename XScaleDataType_,
           typename YScaleDataType_,
           typename ODataType_,
           typename BlockShape_,

@@ -26,17 +32,20 @@ template <typename AccDataType_,
 struct DynamicQuantEpilogueProblem
 {
     using AccDataType    = remove_cvref_t<AccDataType_>;
+    using XScaleDataType = remove_cvref_t<XScaleDataType_>;
     using YScaleDataType = remove_cvref_t<YScaleDataType_>;
     using ODataType      = remove_cvref_t<ODataType_>;
     using BlockShape     = remove_cvref_t<BlockShape_>; // can consum generic 2d shape
     using Traits         = remove_cvref_t<Traits_>;
 };

+// TODO: we should put descriptor creation function into policy
 template <typename Problem_, typename Policy_ = void>
 struct DynamicQuantEpilogue
 {
     using Problem        = remove_cvref_t<Problem_>;
     using AccDataType    = remove_cvref_t<typename Problem::AccDataType>;
+    using XScaleDataType = remove_cvref_t<typename Problem::XScaleDataType>;
     using YScaleDataType = remove_cvref_t<typename Problem::YScaleDataType>;
     using ODataType      = remove_cvref_t<typename Problem::ODataType>;
     using BlockShape     = remove_cvref_t<typename Problem::BlockShape>;

@@ -63,6 +72,33 @@ struct DynamicQuantEpilogue
         return BlockReduce2dCrossWarpSync<P_>{};
     }

+    CK_TILE_DEVICE static constexpr auto MakeSmoothInputScaleTileDistribution()
+    {
+        using S = BlockShape;
+#if 0
+        // don't remove this
+        // Note that if we set encoding purposely like this, you will result in compile fail
+        // TODO: x_scale create local-scratch to accept arbitrary acc input (with same length)
+        return make_static_tile_distribution(
+            tile_distribution_encoding<
+                sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M>,
+                tuple<sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::Vector_N>>,
+                tuple<sequence<0, 1>, sequence<0, 1>>,
+                tuple<sequence<1, 1>, sequence<2, 2>>,
+                sequence<0, 1, 1>,
+                sequence<0, 0, 3>>{});
+#else
+        return make_static_tile_distribution(
+            tile_distribution_encoding<
+                sequence<S::WarpPerBlock_M, S::ThreadPerWarp_M>,
+                tuple<sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::Vector_N>>,
+                tuple<sequence<0, 1>, sequence<0, 1>>,
+                tuple<sequence<0, 1>, sequence<1, 2>>,
+                sequence<1, 1>,
+                sequence<0, 3>>{});
+#endif
+    }
+
     CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
     {
         auto reduce_crosswarp_sync = GetBlockReduce2dCrossWarpSync();

@@ -71,8 +107,12 @@ struct DynamicQuantEpilogue
     // TODO: this function assume store out vector size is the same as OAccTile last dimension size
     // how do we fix this ?
-    template <typename ODramWindowTmp, typename YScaleWindow, typename OAccTile>
+    template <typename ODramWindowTmp, typename XScaleWindow, typename YScaleWindow, typename OAccTile>
     CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp,
+                                   const XScaleWindow& x_scale_window_,
                                    YScaleWindow& y_scale_window,
                                    const OAccTile& o_acc_tile,
                                    void* smem)

@@ -80,6 +120,18 @@ struct DynamicQuantEpilogue
         auto reduce                = GetBlockReduce2d();
         auto reduce_sync           = GetBlockReduce2dSync();
         auto reduce_crosswarp_sync = GetBlockReduce2dCrossWarpSync();

+        const auto x_scale_window =
+            make_tile_window(x_scale_window_, MakeSmoothInputScaleTileDistribution());
+        auto x_scale = load_tile(x_scale_window);
+
+        auto o_acc_tmp = o_acc_tile;
+        sweep_tile(o_acc_tmp, [&](auto idx) {
+            constexpr auto j_idx = make_tuple(idx[number<1>{}]);
+            const auto xs_       = type_convert<AccDataType>(x_scale[j_idx]);
+            o_acc_tmp(idx)       = o_acc_tmp(idx) * xs_;
+        });
+
         const auto f_absmax = [](auto acc_, auto v_0_) { return max(acc_, abs(v_0_)); };

@@ -87,10 +139,9 @@ struct DynamicQuantEpilogue
             constexpr auto y_size_per_row =
                 OAccTile{}.get_tile_distribution().get_ys_to_d_descriptor().get_lengths().at(
                     number<1>{});
+            // constexpr auto y_size_per_row = OAccTile::get_lengths()[number<1>{}];
             if constexpr(UseMax3 && std::is_same_v<AccDataType, float> && y_size_per_row % 2 == 0)
             {
-                // fast max3 implementation
+                // fast max3 +abs implementation
                 const auto f_max3 = [](auto acc_, auto v_0_, auto v_1_) {
                     float rtn;
                     asm volatile("v_max3_f32 %0, %1, abs(%2), abs(%3)"

@@ -98,11 +149,11 @@ struct DynamicQuantEpilogue
                                  : "v"(acc_), "v"(v_0_), "v"(v_1_));
                     return rtn;
                 };
-                return reduce(o_acc_tile, type_convert<AccDataType>(0), f_max3, sequence<1, 2>{});
+                return reduce(o_acc_tmp, type_convert<AccDataType>(0), f_max3, sequence<1, 2>{});
             }
             else
             {
-                return reduce(o_acc_tile, type_convert<AccDataType>(0), f_absmax);
+                return reduce(o_acc_tmp, type_convert<AccDataType>(0), f_absmax);
             }
         }();
         reduce_sync(row_absmax, f_absmax);

@@ -117,23 +168,20 @@ struct DynamicQuantEpilogue
         store_tile(y_scale_window, cast_tile<YScaleDataType>(y_scale));

-        auto o_acc_scaled_tile =
-            make_static_distributed_tensor<AccDataType>(o_acc_tile.get_tile_distribution());
-        sweep_tile(o_acc_tile, [&](auto idx) {
-            constexpr auto row_id  = make_tuple(idx[number<0>{}]);
-            o_acc_scaled_tile(idx) = o_acc_tile[idx] / y_scale(row_id);
-        });
+        sweep_tile(o_acc_tmp, [&](auto idx) {
+            constexpr auto row_id = make_tuple(idx[number<0>{}]);
+            o_acc_tmp(idx)        = o_acc_tmp[idx] / y_scale(row_id);
+        });

         // TODO: this is ugly
         if constexpr(UseRawStore && (kPadM || kPadN))
         {
-            store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_scaled_tile));
+            store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tmp));
             buffer_store_fence();
         }
         else
         {
-            store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_scaled_tile));
+            store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tmp));
         }
     }
 };
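Taken together, the new epilogue path multiplies the accumulator tile columnwise by x_scale, reduces each row to its absolute maximum, derives the rowwise quant scale, and stores the divided result. A minimal scalar sketch of that math in plain C++ (not the CK Tile API; int8 output with qmax = 127 is an assumed example, and a real kernel would also guard against absmax == 0):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Scalar sketch of the fused smooth-quant + rowwise dynamic-quant epilogue.
void dynamic_quant_epilogue_ref(const std::vector<float>& acc,     // [m, n] accumulator tile
                                const std::vector<float>& x_scale, // [n] columnwise smooth scale
                                std::vector<float>& y_scale,       // [m] rowwise quant scale out
                                std::vector<int8_t>& out,          // [m, n] quantized output
                                int m,
                                int n)
{
    constexpr float qmax = 127.0f;
    for(int i = 0; i < m; ++i)
    {
        float absmax = 0.0f;
        for(int j = 0; j < n; ++j) // columnwise pre-scale, then rowwise absmax
            absmax = std::max(absmax, std::fabs(acc[i * n + j] * x_scale[j]));

        y_scale[i] = absmax / qmax; // what the kernel stores through y_scale_window

        for(int j = 0; j < n; ++j) // divide by the row scale and saturate-cast
        {
            float q        = acc[i * n + j] * x_scale[j] / y_scale[i];
            out[i * n + j] = static_cast<int8_t>(std::clamp(q, -qmax, qmax));
        }
    }
}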
include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp

@@ -117,7 +117,7 @@ struct Layernorm2dFwd
     CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs)
     {
-        return (hargs.m + Block_M - 1) / Block_M;
+        return dim3(integer_divide_ceil(hargs.m, Block_M));
     }

     CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::BlockSize; }

@@ -165,7 +165,7 @@ struct Layernorm2dFwd
             return base_str;
         }();

         return _SS_("layernorm2d_fwd_") + _SS_(prec_str) + "_" +
                _TS_(S_::Block_M) + "x" + _TS_(S_::Block_N) + "_" + _TS_(S_::WarpPerBlock_M) + "x" + _TS_(S_::WarpPerBlock_N) + "_" +
                _TS_(S_::Warp_M) + "x" + _TS_(S_::Warp_N) + "_" + _TS_(S_::Vector_M) + "x" + _TS_(S_::Vector_N) + "_" +
                _SS_(Pipeline::name) + surfix;
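GridSize now returns a dim3 built from a ceiling division of the row count by Block_M instead of a bare integer. The arithmetic is the usual ceil-div; a plain C++ sketch (integer_divide_ceil_ref is a hypothetical stand-in for ck_tile::integer_divide_ceil):

// One workgroup per Block_M rows, rounded up so a partial tail tile still
// gets a block.
constexpr int integer_divide_ceil_ref(int a, int b) { return (a + b - 1) / b; }

static_assert(integer_divide_ceil_ref(1000, 64) == 16); // 1000 rows, Block_M = 64
static_assert(integer_divide_ceil_ref(1024, 64) == 16); // exact multiple: no extra block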
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp

@@ -26,6 +26,7 @@ struct Layernorm2dFwdPipelineDefaultPolicy
                 sequence<1, 1, 2, 2>,
                 sequence<0, 3, 0, 3>>{});
     }

     template <typename Problem>
     CK_TILE_DEVICE static constexpr auto MakeGammaBetaBlockTileDistribution()
     {

@@ -44,7 +45,7 @@ struct Layernorm2dFwdPipelineDefaultPolicy
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto GetBlockWelford()
     {
-        using P_ = BlockWelfordProblem<typename Problem::XDataType,
+        using P_ = BlockWelfordProblem<typename Problem::ComputeDataType,
                                        typename Problem::ComputeDataType,
                                        typename Problem::BlockShape>;

@@ -54,7 +55,7 @@ struct Layernorm2dFwdPipelineDefaultPolicy
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto GetBlockWelfordSync()
     {
-        using P_ = BlockWelfordProblem<typename Problem::XDataType,
+        using P_ = BlockWelfordProblem<typename Problem::ComputeDataType,
                                        typename Problem::ComputeDataType,
                                        typename Problem::BlockShape>;

@@ -64,7 +65,7 @@ struct Layernorm2dFwdPipelineDefaultPolicy
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto GetBlockWelfordCrossWarpSync()
     {
-        using P_ = BlockWelfordProblem<typename Problem::XDataType,
+        using P_ = BlockWelfordProblem<typename Problem::ComputeDataType,
                                        typename Problem::ComputeDataType,
                                        typename Problem::BlockShape>;

@@ -76,13 +77,13 @@ struct Layernorm2dFwdPipelineDefaultPolicy
     {
         if constexpr(Problem::kNeedCrossWarpSync)
         {
-            using P_ = BlockWelfordProblem<typename Problem::XDataType,
+            using P_ = BlockWelfordProblem<typename Problem::ComputeDataType,
                                            typename Problem::ComputeDataType,
                                            typename Problem::BlockShape>;
             using block_welford = BlockWelford<P_>;
             using x_block_tile  =
-                decltype(make_static_distributed_tensor<typename Problem::XDataType>(
+                decltype(make_static_distributed_tensor<typename Problem::ComputeDataType>(
                     MakeXBlockTileDistribution<Problem>()));
             using mean_var_block_tile =
                 decltype(block_welford::template MakeMeanVarBlockTile<x_block_tile>());
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp

@@ -87,12 +87,9 @@ struct Layernorm2dFwdPipelineOnePass
             x_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
         auto y_residual_window = make_tile_window(
             y_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
-        const auto x_scale_window = make_tile_window(
-            x_scale_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());

         auto x      = load_tile(x_window);
         auto x_resi = load_tile(x_residual_window);
-        auto x_scale = load_tile(x_scale_window);
         int cur_count = 0;
         int max_count =

@@ -106,20 +103,21 @@ struct Layernorm2dFwdPipelineOnePass
         const auto gamma = load_tile(gamma_window);
         const auto beta  = load_tile(beta_window);

+        auto acc = cast_tile<ComputeDataType>(x);
         if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
                      kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD)
         {
             sweep_tile(x_resi, [&](auto idx) {
                 // compute x = x_resi + x
-                x(idx) = type_convert<YResidualDataType>(x_resi(idx)) +
-                         type_convert<YResidualDataType>(x(idx));
+                acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx);
             });
             if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE)
-                store_tile(y_residual_window, x);
+                store_tile(y_residual_window, cast_tile<YResidualDataType>(acc));
         }

         // compute welford each-thread->cross-lane->cross-warp
-        auto [mean, var] = block_welford(x, cur_count, max_count);
+        auto [mean, var] = block_welford(acc, cur_count, max_count);
         block_welford_sync(mean, var, cur_count);
         block_welford_cross_warp_sync(mean, var, cur_count, smem);
         block_tile_welford_post_scale_var(var, cur_count);

@@ -137,7 +135,7 @@ struct Layernorm2dFwdPipelineOnePass
         store_tile(inv_std_window, cast_tile<InvStdDataType>(inv_std));

         // layernorm computation
-        auto ln = make_static_distributed_tensor<ComputeDataType>(x.get_tile_distribution());
+        auto ln = make_static_distributed_tensor<ComputeDataType>(acc.get_tile_distribution());
         sweep_tile(ln, [&, mean_ = mean](auto idx) {
             constexpr auto i_idx = make_tuple(idx[number<0>{}]);
             constexpr auto j_idx = make_tuple(idx[number<1>{}]);

@@ -145,26 +143,15 @@ struct Layernorm2dFwdPipelineOnePass
             const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]);
             const auto beta_  = type_convert<ComputeDataType>(beta[j_idx]);

-            const auto x_ = type_convert<ComputeDataType>(x[idx]);
-            auto ln_      = (x_ - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_;
+            auto ln_ = (acc[idx] - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_;
             ln(idx) = ln_;
         });

-        if constexpr(kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT)
-        {
-            // smooth-quant pre-scale, then run rowwise-quant
-            sweep_tile(ln, [&](auto idx) {
-                constexpr auto j_idx = make_tuple(idx[number<1>{}]);
-                const auto xs_       = type_convert<ComputeDataType>(x_scale[j_idx]);
-                ln(idx)              = ln(idx) * xs_;
-            });
-        }
         if constexpr(kFusedQuant == Layernorm2dFusedQuantEnum::DYNAMIC_QUANT ||
                      kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT)
         {
-            Epilogue{}(y_window_, y_scale_window, ln, smem);
+            Epilogue{}(y_window_, x_scale_window_, y_scale_window, ln, smem);
         }
         else
             Epilogue{}(y_window_, ln);
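The recurring pattern in these hunks is that the pipeline now keeps a ComputeDataType copy (acc) of the loaded tile and runs the residual add, the Welford statistics, and the normalization on it, instead of round-tripping values through the fp16/bf16 x tile. A scalar sketch of that flow for one row, assuming fp32 compute (plain C++, not the CK Tile API):

#include <cmath>
#include <vector>

// Scalar sketch of the fused PRE_ADD layernorm row: residual add in fp32,
// one-pass Welford mean/variance, then normalize with gamma/beta.
void layernorm_fused_add_ref(std::vector<float>& acc,        // [n], inputs already widened to fp32
                             const std::vector<float>& resi, // [n] residual (PRE_ADD path)
                             const std::vector<float>& gamma,
                             const std::vector<float>& beta,
                             std::vector<float>& y,
                             float epsilon)
{
    const int n = static_cast<int>(acc.size());
    float mean = 0.0f, m2 = 0.0f;
    for(int j = 0; j < n; ++j)
    {
        acc[j] += resi[j];             // fused residual add (acc = x + x_resi)
        float delta = acc[j] - mean;   // Welford running update
        mean += delta / static_cast<float>(j + 1);
        m2 += delta * (acc[j] - mean);
    }
    float var     = m2 / static_cast<float>(n);
    float inv_std = 1.0f / std::sqrt(var + epsilon);
    for(int j = 0; j < n; ++j)
        y[j] = (acc[j] - mean) * inv_std * gamma[j] + beta[j];
}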
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_problem.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp

@@ -106,7 +106,7 @@ struct Layernorm2dFwdPipelineTwoPass
         auto block_welford_cross_warp_sync =
             Policy::template GetBlockWelfordCrossWarpSync<Problem>();

-        using XTensorType = decltype(load_tile(x_window));
+        using XTensorType = decltype(cast_tile<ComputeDataType>(load_tile(x_window)));
         auto mean = block_welford.template MakeMeanVarBlockTile<XTensorType>();
         auto var  = block_welford.template MakeMeanVarBlockTile<XTensorType>();

@@ -117,21 +117,22 @@ struct Layernorm2dFwdPipelineTwoPass
             move_tile_window(x_window, {0, Block_N});
             move_tile_window(x_residual_window, {0, Block_N});

+            auto acc = cast_tile<ComputeDataType>(x);
             if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
                          kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD)
             {
                 sweep_tile(x_resi, [&](auto idx) {
                     // compute x = x_resi + x
-                    x(idx) = type_convert<YResidualDataType>(x_resi(idx)) +
-                             type_convert<YResidualDataType>(x(idx));
+                    acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx);
                 });
                 if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE)
                 {
-                    store_tile(y_residual_window, x);
+                    store_tile(y_residual_window, cast_tile<YResidualDataType>(acc));
                     move_tile_window(y_residual_window, {0, Block_N});
                 }
             }
-            block_welford(x, mean, var, cur_count, max_count);
+            block_welford(acc, mean, var, cur_count, max_count);
         }
         block_welford_sync(mean, var, cur_count);

@@ -165,20 +166,21 @@ struct Layernorm2dFwdPipelineTwoPass
         {
             auto x      = load_tile(x_window);
             auto x_resi = load_tile(x_residual_window);

+            auto acc = cast_tile<ComputeDataType>(x);
             if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
                          kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD)
             {
                 sweep_tile(x_resi, [&](auto idx) {
                     // compute x = x_resi + x
-                    x(idx) = type_convert<YResidualDataType>(x_resi(idx)) +
-                             type_convert<YResidualDataType>(x(idx));
+                    acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx);
                 });
             }

             // load gamma/beta (TODO: support no gamma/beta?)
             const auto gamma = load_tile(gamma_window);
             const auto beta  = load_tile(beta_window);

-            auto ln = make_static_distributed_tensor<ComputeDataType>(x.get_tile_distribution());
+            auto ln = make_static_distributed_tensor<ComputeDataType>(acc.get_tile_distribution());
             sweep_tile(ln, [&, mean_ = mean](auto idx) {
                 constexpr auto i_idx = make_tuple(idx[number<0>{}]);

@@ -187,8 +189,7 @@ struct Layernorm2dFwdPipelineTwoPass
                 const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]);
                 const auto beta_  = type_convert<ComputeDataType>(beta[j_idx]);

-                const auto x_ = type_convert<ComputeDataType>(x[idx]);
-                auto ln_      = (x_ - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_;
+                auto ln_ = (acc(idx) - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_;
                 ln(idx) = ln_;
             });
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
include/ck_tile/ops/reduce/block/block_reduce2d.hpp

@@ -29,7 +29,8 @@ struct BlockReduce2d
         sweep_tile<XDistributedTensor_>(
             [&](auto... idx_) {
                 constexpr auto idx_0 = make_tuple(make_tuple(idx_[number<0>{}]...)[number<0>{}]);
-                y_tensor(idx_0) = reduce_func(y_tensor(idx_0), x_tensor[idx_]...);
+                y_tensor(idx_0) = reduce_func(
+                    y_tensor(idx_0), ck_tile::type_convert<ComputeDataType>(x_tensor[idx_])...);
             },
             ReducePacksPerXDim{});
 #if 0
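The one-line change converts each element to ComputeDataType before it reaches reduce_func, so low-precision inputs are reduced in a wider accumulator. A small self-contained illustration of why that matters (the _Float16 type is a compiler extension available in recent clang/gcc; this is an assumed demo, not CK code):

#include <cstdio>

// Summing many fp16 values in an fp16 accumulator stalls once the sum
// outgrows the format's precision; converting every element to float first
// (what ck_tile::type_convert<ComputeDataType> does per element in the
// changed line) keeps this particular sum exact.
int main()
{
    _Float16 acc_half = 0;
    float acc_float   = 0.0f;
    for(int i = 0; i < 4096; ++i)
    {
        _Float16 v = 1; // every element is exactly 1
        acc_half   = acc_half + v;                      // stalls at 2048: 2048 + 1 rounds back to 2048
        acc_float  = acc_float + static_cast<float>(v); // exact: 4096
    }
    std::printf("fp16 accumulator: %f, fp32 accumulator: %f\n",
                static_cast<float>(acc_half), acc_float);
}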
include/ck_tile/ops/rmsnorm2d.hpp

@@ -4,7 +4,6 @@
 #pragma once

 #include "ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp"
-#include "ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_shape.hpp"
 #include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp"
 #include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp"
 #include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_problem.hpp"
include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp

@@ -11,11 +11,11 @@ namespace ck_tile {
 // host side args
 struct Rmsnorm2dFwdHostArgs
 {
-    const void* p_x;
-    const void* p_gamma;
-    void* p_y;
-    void* p_invRms;
+    const void* p_x;     // [m ,n], input, fp16/bf16
+    const void* p_gamma; // [1, n], gamma, prec same as input
+    void* p_y;           // [m, n], output, fp16/bf16
+    void* p_invRms;      // [m, 1], output inv-rms, prec same as input, nullptr if not used

     float epsilon;

@@ -83,7 +83,7 @@ struct Rmsnorm2dFwd
     CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs)
     {
-        return (hargs.m + Block_M - 1) / Block_M;
+        return dim3(integer_divide_ceil(hargs.m, Block_M));
     }

     CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::BlockSize; }

@@ -149,7 +149,7 @@ struct Rmsnorm2dFwd
                 number<1>{});
             const auto tmp2_ =
-                pad_tensor_view(tmp_, make_tuple(number<Block_N>{}), sequence<kPadM>{});
+                pad_tensor_view(tmp_, make_tuple(number<Block_N>{}), sequence<kPadN>{});
             return make_tile_window(tmp2_, make_tuple(number<Block_N>{}), {0});
         }();
include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_shape.hpp (deleted, 100644 → 0)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"

namespace ck_tile {

/*
// clang-format off
4-level descriptor: BlockTile-> WarpPerBlock-> WarpTile-> Vector

        Block_N (Warp_N * WarpPerBlock_N * Repeat_N)
        +<----------------------< Repeat_N(2)>--------------------->+
        |                                                           |
        +<-- <WarpPerBlock_N(2)> -->+
                 Warp_N
        +--------------+--------------+--------------+--------------+----+----------------+
 Warp_M |    wrap_0    |    wrap_1    |              |              |  ^                ^
        +--------------+--------------+              |              | <WarpPerBlock_M(2)> |
        |    wrap_2    |    wrap_3    |              |              |  v
        +--------------+--------------+--------------+--------------+----+            Block_M
        |              |              |                                                  |
        +              +              |                                                  |
        |              |              |                                                  v
        +--------------+--------------+--------------+--------------+                    +

  each Warp-tile (e.g 16 thrd per row)
          Vector_N (contiguous pixels each thrd holds along N, or vector size)
        +-----------+-----------+-----------+-----------+-----------+
        |  thrd_0   |  thrd_1   |  thrd_2   |  thrd_3   | ...          Vector_M
        +-----------+-----------+-----------+-----------+-----------+
        |  thrd_16  |  thrd_17  |  thrd_18  |  thrd_19  | ...
        +-----------+-----------+-----------+-----------+-----------+
// clang-format on
*/
template <typename BlockTile_,    // block size, seq<M, N>
          typename WarpPerBlock_, // num warps along seq<M, N>
          typename WarpTile_,     // warp size, seq<M, N>
          typename Vector_,       // contiguous pixels(vector size) along seq<M, N>
          index_t BlockSize_ =
              warpSize * reduce_on_sequence(WarpPerBlock_{}, multiplies{}, number<1>{})>
struct Rmsnorm2dShape
{
    // block size
    static constexpr index_t Block_M = BlockTile_::at(number<0>{});
    static constexpr index_t Block_N = BlockTile_::at(number<1>{});

    // num warps along seq<M, N>, within each block
    static constexpr index_t WarpPerBlock_M = WarpPerBlock_::at(number<0>{});
    static constexpr index_t WarpPerBlock_N = WarpPerBlock_::at(number<1>{});

    // warp size
    static constexpr index_t Warp_M = WarpTile_::at(number<0>{});
    static constexpr index_t Warp_N = WarpTile_::at(number<1>{});

    static_assert(Block_M % (WarpPerBlock_M * Warp_M) == 0);
    static_assert(Block_N % (WarpPerBlock_N * Warp_N) == 0);

    // repeat of each thread along seq<M, N>
    static constexpr index_t Repeat_M = Block_M / (WarpPerBlock_M * Warp_M);
    static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N);

    // vector size along seq<M, N>
    static constexpr index_t Vector_M = Vector_::at(number<0>{});
    static constexpr index_t Vector_N = Vector_::at(number<1>{});

    static_assert(Warp_M % Vector_M == 0);
    static_assert(Warp_N % Vector_N == 0);

    // num of threads along seq<M, N>, within each warp
    static constexpr index_t ThreadPerWarp_M = Warp_M / Vector_M;
    static constexpr index_t ThreadPerWarp_N = Warp_N / Vector_N;

    static constexpr index_t BlockSize = BlockSize_;
};

} // namespace ck_tile
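This header is deleted, presumably superseded by the generic 2-D block shape that the new smoothquant.hpp pulls in via ck_tile/ops/common/generic_2d_block_shape.hpp. Its 4-level arithmetic is easy to check by hand; a worked instance in plain C++ with assumed tile numbers (chosen to match the "16 thrd per row" example in the comment above):

// Worked instance of the 4-level shape arithmetic, restated in plain C++.
// The numbers are assumptions for illustration:
// BlockTile = <2, 256>, WarpPerBlock = <1, 2>, WarpTile = <1, 64>, Vector = <1, 4>.
constexpr int Block_M = 2, Block_N = 256;
constexpr int WarpPerBlock_M = 1, WarpPerBlock_N = 2;
constexpr int Warp_M = 1, Warp_N = 64;    // elements covered by one warp
constexpr int Vector_M = 1, Vector_N = 4; // contiguous elements per thread

constexpr int Repeat_M = Block_M / (WarpPerBlock_M * Warp_M); // = 2
constexpr int Repeat_N = Block_N / (WarpPerBlock_N * Warp_N); // = 2
constexpr int ThreadPerWarp_M = Warp_M / Vector_M;            // = 1
constexpr int ThreadPerWarp_N = Warp_N / Vector_N;            // = 16 threads per row

static_assert(Repeat_M == 2 && Repeat_N == 2);
static_assert(ThreadPerWarp_M * ThreadPerWarp_N <= 64); // fits one 64-lane wave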
include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp

@@ -26,6 +26,7 @@ struct Rmsnorm2dFwdPipelineDefaultPolicy
                 sequence<1, 1, 2, 2>,
                 sequence<0, 3, 0, 3>>{});
     }

     template <typename Problem>
     CK_TILE_DEVICE static constexpr auto MakeGammaBlockTileDistribution()
     {
include/ck_tile/ops/smoothquant.hpp (new file, 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/ops/smoothquant/kernel/smoothquant_kernel.hpp"
#include "ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_default_policy.hpp"
#include "ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_one_pass.hpp"
#include "ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_problem.hpp"
#include "ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_two_pass.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/smoothquant/kernel/smoothquant_kernel.hpp (new file, 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"

namespace ck_tile {

// host side args
struct SmoothquantHostArgs
{
    const void* p_x;      // [m ,n], input, fp16/bf16
    const void* p_xscale; // [1, n], input, columnwise scale, fp32
    void* p_yscale;       // [m, 1], output, rowwise quant scale (amax / 127) of (p_x * p_xscale)
    void* p_qy;           // [m, n], output, p_x * p_xscale / p_yscale

    index_t m;
    index_t n;
    index_t stride; // row_stride
};

// TODO: Extract some type to wrapper class
template <typename Pipeline_>
struct Smoothquant
{
    using Pipeline = remove_cvref_t<Pipeline_>;
    using Problem  = typename Pipeline::Problem;

    using XDataType       = remove_cvref_t<typename Problem::XDataType>;
    using XScaleDataType  = remove_cvref_t<typename Problem::XScaleDataType>;
    using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
    using YScaleDataType  = remove_cvref_t<typename Problem::YScaleDataType>;
    using QYDataType      = remove_cvref_t<typename Problem::QYDataType>;

    static constexpr index_t Block_M = Problem::BlockShape::Block_M;
    static constexpr index_t Block_N = Problem::BlockShape::Block_N;
    static constexpr bool kPadM      = false; // always no need to pad along M
    static constexpr bool kPadN      = Problem::kPadN;
    static constexpr bool kTwoPass   = Problem::kTwoPass;

    static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N;
    static constexpr index_t Vector_N        = Problem::BlockShape::Vector_N;
    static constexpr index_t Repeat_N        = Problem::BlockShape::Repeat_N;

    static constexpr auto I0 = number<0>{};
    static constexpr auto I1 = number<1>{};

    struct Kargs
    {
        const void* p_x;
        const void* p_xscale;
        void* p_yscale;
        void* p_qy;

        index_t m;
        index_t n;
        index_t stride; // row_stride
    };
    using Hargs = SmoothquantHostArgs;

    CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs)
    {
        return Kargs{hargs.p_x, hargs.p_xscale, hargs.p_yscale, hargs.p_qy,
                     hargs.m, hargs.n, hargs.stride};
    }

    CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs)
    {
        return dim3(integer_divide_ceil(hargs.m, Block_M));
    }

    CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::BlockSize; }

    // clang-format off
    template <typename T> struct t2s;
    template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
    template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
    template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
    template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
    template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
    // clang-format on

    // in byte
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return Pipeline::GetSmemSize(); }

    CK_TILE_HOST static std::string GetName()
    {
        // clang-format off
        using S_ = typename Problem::BlockShape;
        auto surfix = [&] () {
            std::string n;
            if (kPadN) n += "_pn";
            if (kTwoPass) n += "_2p";
            return n; }();

        #define _SS_  std::string
        #define _TS_  std::to_string
        return _SS_("smoothquant_fwd_") + _SS_(t2s<XDataType>::name) + "_" +
             _TS_(S_::Block_M) + "x" + _TS_(S_::Block_N) + "_" + _TS_(S_::WarpPerBlock_M) + "x" + _TS_(S_::WarpPerBlock_N) + "_" +
             _TS_(S_::Warp_M) + "x" + _TS_(S_::Warp_N) + "_" + _TS_(S_::Vector_M) + "x" + _TS_(S_::Vector_N) + "_" +
             _SS_(Pipeline::name) + surfix;
        #undef _SS_
        #undef _TS_
        // clang-format on
    }

    CK_TILE_DEVICE void operator()(Kargs kargs) const
    {
        const auto iM = get_block_id() * Block_M;

        const auto x_window = [&]() {
            const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
                static_cast<const XDataType*>(kargs.p_x),
                make_tuple(kargs.m, kargs.n),
                make_tuple(kargs.stride, 1),
                number<Vector_N>{},
                number<1>{});

            const auto tmp2_ = pad_tensor_view(
                tmp_, make_tuple(number<Block_M>{}, number<Block_N>{}), sequence<kPadM, kPadN>{});
            return make_tile_window(
                tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
        }();

        const auto xscale_window = [&]() {
            const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
                static_cast<const XScaleDataType*>(kargs.p_xscale),
                make_tuple(kargs.n),
                make_tuple(1),
                number<Vector_N>{},
                number<1>{});

            const auto tmp2_ =
                pad_tensor_view(tmp_, make_tuple(number<Block_N>{}), sequence<kPadN>{});
            return make_tile_window(tmp2_, make_tuple(number<Block_N>{}), {0});
        }();

        auto yscale_window = [&]() {
            const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
                static_cast<YScaleDataType*>(kargs.p_yscale),
                make_tuple(kargs.m),
                make_tuple(1),
                number<1>{});

            const auto tmp2_ =
                pad_tensor_view(tmp_, make_tuple(number<Block_M>{}), sequence<kPadM>{});
            return make_tile_window(tmp2_, make_tuple(number<Block_M>{}), {iM});
        }();

        auto qy_window = [&]() {
            auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
                static_cast<QYDataType*>(kargs.p_qy),
                make_tuple(kargs.m, kargs.n),
                make_tuple(kargs.stride, 1),
                number<Vector_N>{},
                number<1>{});

            auto tmp2_ = pad_tensor_view(
                tmp_, make_tuple(number<Block_M>{}, number<Block_N>{}), sequence<kPadM, kPadN>{});
            return make_tile_window(
                tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
        }();

        __shared__ char smem[GetSmemSize()];

        Pipeline{}(x_window, xscale_window, yscale_window, qy_window, kargs.n, smem);
    }
};

} // namespace ck_tile
include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_default_policy.hpp (new file, 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/reduce/block/block_reduce2d_problem.hpp"
#include "ck_tile/ops/reduce/block/block_reduce2d.hpp"

namespace ck_tile {

struct SmoothquantPipelineDefaultPolicy
{
    template <typename Problem>
    CK_TILE_DEVICE static constexpr auto MakeXBlockTileDistribution()
    {
        using S = typename Problem::BlockShape;

        return make_static_tile_distribution(
            tile_distribution_encoding<
                sequence<>,
                tuple<sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M, S::Vector_M>,
                      sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::Vector_N>>,
                tuple<sequence<1, 2>, sequence<1, 2>>,
                tuple<sequence<1, 1>, sequence<2, 2>>,
                sequence<1, 1, 2, 2>,
                sequence<0, 3, 0, 3>>{});
    }

    template <typename Problem>
    CK_TILE_DEVICE static constexpr auto MakeXScaleBlockTileDistribution()
    {
        using S = typename Problem::BlockShape;

        return make_static_tile_distribution(
            tile_distribution_encoding<
                sequence<S::WarpPerBlock_M, S::ThreadPerWarp_M>,
                tuple<sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::Vector_N>>,
                tuple<sequence<0, 1>, sequence<0, 1>>,
                tuple<sequence<0, 1>, sequence<1, 2>>,
                sequence<1, 1>,
                sequence<0, 3>>{});
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2d()
    {
        using P_ = BlockReduce2dProblem<typename Problem::ComputeDataType,
                                        typename Problem::ComputeDataType,
                                        typename Problem::BlockShape>;

        return BlockReduce2d<P_>{};
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dSync()
    {
        using P_ = BlockReduce2dProblem<typename Problem::ComputeDataType,
                                        typename Problem::ComputeDataType,
                                        typename Problem::BlockShape>;

        return BlockReduce2dSync<P_>{};
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dCrossWarpSync()
    {
        using P_ = BlockReduce2dProblem<typename Problem::ComputeDataType,
                                        typename Problem::ComputeDataType,
                                        typename Problem::BlockShape>;

        return BlockReduce2dCrossWarpSync<P_>{};
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
    {
        if constexpr(Problem::kNeedCrossWarpSync)
        {
            using P_ = BlockReduce2dProblem<typename Problem::XDataType,
                                            typename Problem::ComputeDataType,
                                            typename Problem::BlockShape>;
            using block_reduce2d = BlockReduce2d<P_>;
            using x_block_tile =
                decltype(make_static_distributed_tensor<typename Problem::XDataType>(
                    MakeXBlockTileDistribution<Problem>()));
            using y_block_tile = decltype(block_reduce2d::template MakeYBlockTile<x_block_tile>());

            return GetBlockReduce2dCrossWarpSync<Problem>().template GetSmemSize<y_block_tile>();
        }
        else
        {
            return 1; // zero size arrays are an extension
        }
    }
};
} // namespace ck_tile
include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_one_pass.hpp (new file, 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp"
#include <string>
#include <type_traits>

namespace ck_tile {

template <typename Problem_, typename Policy_ = SmoothquantPipelineDefaultPolicy>
struct SmoothquantPipelineOnePass
{
    using Problem = ck_tile::remove_cvref_t<Problem_>;
    using Policy  = ck_tile::remove_cvref_t<Policy_>;

    using XDataType       = ck_tile::remove_cvref_t<typename Problem::XDataType>;
    using XScaleDataType  = ck_tile::remove_cvref_t<typename Problem::XScaleDataType>;
    using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
    using QYDataType      = ck_tile::remove_cvref_t<typename Problem::QYDataType>;
    using YScaleDataType  = ck_tile::remove_cvref_t<typename Problem::YScaleDataType>;

    static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
    static constexpr bool kPadM = false; // TODO - BlockSmoothquantProblem::kPadM
    static constexpr bool kPadN = Problem::kPadN;

    static constexpr const char* name = []() {
        if constexpr(kNeedCrossWarpSync)
            return "bpr_op"; // block per row
        else
            return "wpr_op"; // warp per row
    }();

    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
    {
        return Policy::template GetSmemSize<Problem>();
    }

    template <typename XWindow, typename XScaleWindow, typename QYWindow, typename YScaleWindow>
    CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
                                   const XScaleWindow& xscale_window_,
                                   YScaleWindow& yscale_window,
                                   QYWindow& qy_window,
                                   ck_tile::index_t,
                                   void* smem) const
    {
        auto x_window =
            make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
        auto xscale_window = make_tile_window(
            xscale_window_, Policy::template MakeXScaleBlockTileDistribution<Problem>());

        auto reduce_absmax_func = ReduceOp::AbsMax{};
        auto reduce_max_func    = ReduceOp::Max{};
        auto block_reduce2d      = Policy::template GetBlockReduce2d<Problem>();
        auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>();
        auto block_reduce2d_cross_warp_sync =
            Policy::template GetBlockReduce2dCrossWarpSync<Problem>();

        const auto x      = load_tile(x_window);
        const auto xscale = load_tile(xscale_window);
        auto y            = tile_elementwise_in(
            [&](const auto& a, const auto& b) {
                return type_convert<ComputeDataType>(a) * type_convert<ComputeDataType>(b);
            },
            x,
            xscale);

        // compute absmax, cross-lane->cross-warp
        auto absmax = block_reduce2d(
            y, reduce_absmax_func.GetIdentityValue<ComputeDataType>(), reduce_absmax_func);
        block_reduce2d_sync(absmax, reduce_max_func);
        block_reduce2d_cross_warp_sync(absmax, smem, reduce_max_func);

        // ex: yscale = absmax / 127 if int8
        auto yscale = tile_elementwise_in(
            [&](const auto& v_) {
                return v_ / type_convert<ComputeDataType>(numeric<QYDataType>::max());
            },
            absmax);

        store_tile(yscale_window, cast_tile<YScaleDataType>(yscale));

        // quantize y to qy
        auto qy = make_static_distributed_tensor<QYDataType>(y.get_tile_distribution());
        sweep_tile(qy, [&](auto idx) {
            constexpr auto i_idx = make_tuple(idx[number<0>{}]);
            auto qy_             = y[idx] / yscale[i_idx];
            qy(idx)              = saturates<QYDataType>{}(qy_);
        });
        store_tile(qy_window, qy);
    }
};
} // namespace ck_tile
include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_problem.hpp (new file, 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/utility/type_traits.hpp"

namespace ck_tile {

// Y = X * XScale, QY = RowwiseDynamicQuant(Y) = SaturateCast(Y / YScale)
template <typename XDataType_,
          typename XScaleDataType_,
          typename ComputeDataType_,
          typename YScaleDataType_,
          typename QYDataType_,
          typename BlockShape_,
          bool kPadN_,
          bool kTwoPass_>
struct SmoothquantPipelineProblem
{
    using XDataType       = remove_cvref_t<XDataType_>;
    using XScaleDataType  = remove_cvref_t<XScaleDataType_>;
    using ComputeDataType = remove_cvref_t<ComputeDataType_>;
    using YScaleDataType  = remove_cvref_t<YScaleDataType_>;
    using QYDataType      = remove_cvref_t<QYDataType_>;
    using BlockShape      = remove_cvref_t<BlockShape_>;

    static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1;
    static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1;

    static constexpr bool kPadN    = kPadN_;
    static constexpr bool kTwoPass = kTwoPass_;
};
} // namespace ck_tile
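The comment at the top of this header is the whole contract: Y = X * XScale, QY = SaturateCast(Y / YScale). A scalar restatement for an int8 QYDataType (plain C++; saturate_cast_int8 is a hypothetical stand-in for the ck_tile::saturates<QYDataType> functor used in the pipelines):

#include <algorithm>
#include <cstdint>

// Clamp into the int8 range before the narrowing cast, so out-of-range
// values saturate instead of wrapping.
inline int8_t saturate_cast_int8(float v)
{
    return static_cast<int8_t>(std::clamp(v, -128.0f, 127.0f));
}

inline int8_t smoothquant_element(float x, float xscale, float yscale)
{
    float y = x * xscale;                  // smooth scaling: Y = X * XScale
    return saturate_cast_int8(y / yscale); // rowwise dynamic quant: QY = SaturateCast(Y / YScale)
}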
include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_two_pass.hpp (new file, 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp"
#include <string>
#include <type_traits>

namespace ck_tile {

template <typename Problem_, typename Policy_ = SmoothquantPipelineDefaultPolicy>
struct SmoothquantPipelineTwoPass
{
    using Problem = ck_tile::remove_cvref_t<Problem_>;
    using Policy  = ck_tile::remove_cvref_t<Policy_>;

    using XDataType       = ck_tile::remove_cvref_t<typename Problem::XDataType>;
    using XScaleDataType  = ck_tile::remove_cvref_t<typename Problem::XScaleDataType>;
    using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
    using QYDataType      = ck_tile::remove_cvref_t<typename Problem::QYDataType>;
    using YScaleDataType  = ck_tile::remove_cvref_t<typename Problem::YScaleDataType>;

    static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
    static constexpr bool kPadM = false; // TODO - BlockSmoothquantProblem::kPadM
    static constexpr bool kPadN = Problem::kPadN;

    static constexpr const char* name = []() {
        if constexpr(kNeedCrossWarpSync)
            return "bpr_tp"; // block per row
        else
            return "wpr_tp"; // warp per row
    }();

    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
    {
        return Policy::template GetSmemSize<Problem>();
    }

    template <typename XWindow, typename XScaleWindow, typename QYWindow, typename YScaleWindow>
    CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
                                   const XScaleWindow& xscale_window_,
                                   YScaleWindow& yscale_window,
                                   QYWindow& qy_window,
                                   ck_tile::index_t row_size,
                                   void* smem) const
    {
        auto x_window =
            make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
        auto xscale_window = make_tile_window(
            xscale_window_, Policy::template MakeXScaleBlockTileDistribution<Problem>());

        static constexpr index_t Block_N = Problem::BlockShape::Block_N;
        index_t num_n_tile_iteration =
            __builtin_amdgcn_readfirstlane(integer_divide_ceil(row_size, Block_N));

        auto reduce_absmax_func = ReduceOp::AbsMax{};
        auto reduce_max_func    = ReduceOp::Max{};
        auto block_reduce2d      = Policy::template GetBlockReduce2d<Problem>();
        auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>();
        auto block_reduce2d_cross_warp_sync =
            Policy::template GetBlockReduce2dCrossWarpSync<Problem>();

        using XTensorType = decltype(cast_tile<ComputeDataType>(load_tile(x_window)));
        auto absmax       = block_reduce2d.template MakeYBlockTile<XTensorType>();
        set_tile(absmax, reduce_absmax_func.GetIdentityValue<ComputeDataType>());

        for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
        {
            const auto x      = load_tile(x_window);
            const auto xscale = load_tile(xscale_window);

            const auto y = tile_elementwise_in(
                [&](const auto& a, const auto& b) {
                    return type_convert<ComputeDataType>(a) * type_convert<ComputeDataType>(b);
                },
                x,
                xscale);

            block_reduce2d(y, absmax, reduce_absmax_func);
            move_tile_window(x_window, {0, Block_N});
            move_tile_window(xscale_window, {Block_N});
        }

        // compute absmax, cross-lane->cross-warp
        block_reduce2d_sync(absmax, reduce_max_func);
        block_reduce2d_cross_warp_sync(absmax, smem, reduce_max_func);

        // ex: yscale = absmax / 127 if int8
        auto yscale = tile_elementwise_in(
            [&](const auto& v_) {
                return v_ / type_convert<ComputeDataType>(numeric<QYDataType>::max());
            },
            absmax);

        store_tile(yscale_window, cast_tile<YScaleDataType>(yscale));

        // reverse read x to reuse cache
        ck_tile::index_t stride_to_right_most_window =
            row_size % Block_N == 0 ? row_size - Block_N : row_size - row_size % Block_N;

        move_tile_window(x_window, {0, -Block_N});
        move_tile_window(xscale_window, {-Block_N});
        move_tile_window(qy_window, {0, stride_to_right_most_window});

        // recompute y and quantize y to qy
        for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
        {
            const auto x      = load_tile(x_window);
            const auto xscale = load_tile(xscale_window);

            const auto y = tile_elementwise_in(
                [&](const auto& a, const auto& b) {
                    return type_convert<ComputeDataType>(a) * type_convert<ComputeDataType>(b);
                },
                x,
                xscale);

            auto qy = make_static_distributed_tensor<QYDataType>(y.get_tile_distribution());
            sweep_tile(qy, [&](auto idx) {
                constexpr auto i_idx = make_tuple(idx[number<0>{}]);
                auto qy_             = y[idx] / yscale[i_idx];
                qy(idx)              = saturates<QYDataType>{}(qy_);
            });
            store_tile(qy_window, qy);

            move_tile_window(x_window, {0, -Block_N});
            move_tile_window(xscale_window, {0, -Block_N});
            move_tile_window(qy_window, {0, -Block_N});
        }
    }
};
} // namespace ck_tile
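The two-pass pipeline streams Block_N-wide tiles left-to-right to build the running absmax, then walks back right-to-left ("reverse read x to reuse cache") recomputing y and quantizing, so the second pass starts on the tiles the first pass touched most recently. A scalar sketch of that schedule for one row (plain C++; the int8 output and the omission of an absmax == 0 guard are assumptions):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Pass 1: scan tiles left-to-right for the row absmax.
// Pass 2: walk back right-to-left, recomputing y = x * xscale and quantizing.
void smoothquant_two_pass_row(const std::vector<float>& x,      // [n]
                              const std::vector<float>& xscale, // [n]
                              float& yscale,
                              std::vector<int8_t>& qy, // [n]
                              int n,
                              int block_n)
{
    float absmax = 0.0f;
    for(int start = 0; start < n; start += block_n) // pass 1
        for(int j = start; j < std::min(start + block_n, n); ++j)
            absmax = std::max(absmax, std::fabs(x[j] * xscale[j]));
    yscale = absmax / 127.0f;

    // Mirrors stride_to_right_most_window: begin at the rightmost tile.
    for(int start = ((n - 1) / block_n) * block_n; start >= 0; start -= block_n) // pass 2
        for(int j = start; j < std::min(start + block_n, n); ++j)
        {
            float q = x[j] * xscale[j] / yscale;
            qy[j]   = static_cast<int8_t>(std::clamp(q, -127.0f, 127.0f));
        }
}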
include/ck_tile/remod.py

+from datetime import datetime
 import pathlib
 from pathlib import Path
 import subprocess

@@ -8,8 +9,8 @@ NS = 'ck_tile'
 OPS = 'ops'
 OPS_COMMON = 'common'   # common header will be duplicated into ops/* other module

-HEADER_COMMON = """// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n
+HEADER_COMMON = f"""// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-{datetime.now().year}, Advanced Micro Devices, Inc. All rights reserved.\n
 """

 # aa/bb/cc/file.hpp -> (aa, bb, cc, file.hpp)
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp

@@ -131,6 +131,31 @@ using device_grouped_conv_fwd_xdl_f32_comp_instances = std::tuple<
     // clang-format on
     >;

+template <index_t NDimSpatial,
+          typename ALayout,
+          typename BLayout,
+          typename DsLayout,
+          typename ELayout,
+          ConvolutionForwardSpecialization ConvSpec>
+using device_grouped_conv_fwd_xdl_int8_comp_instances = std::tuple<
+    // clang-format off
+        //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
+        //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
+        //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
+        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, int8_t, int8_t, int32_t, int8_t, DsLayout, int8_t, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
+        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, int8_t, int8_t, int32_t, int8_t, DsLayout, int8_t, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
+        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, int8_t, int8_t, int32_t, int8_t, DsLayout, int8_t, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
+        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, int8_t, int8_t, int32_t, int8_t, DsLayout, int8_t, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
+        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, int8_t, int8_t, int32_t, int8_t, DsLayout, int8_t, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 256, 256, 256, 32, 8, 8, 16, 16, 8, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
+        // AGPR Spill when use permuted lds layout. so, use padding for these two.
+        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, int8_t, int8_t, int32_t, int8_t, DsLayout, int8_t, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
+        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, int8_t, int8_t, int32_t, int8_t, DsLayout, int8_t, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>,
+        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, int8_t, int8_t, int32_t, int8_t, DsLayout, int8_t, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>,
+        DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, int8_t, int8_t, int32_t, int8_t, DsLayout, int8_t, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>
+    // clang-format on
+    >;
+
 } // namespace instance
 } // namespace device
 } // namespace tensor_operation