Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
7d50244e
Unverified
Commit
7d50244e
authored
Oct 31, 2024
by
Illia Silin
Committed by
GitHub
Oct 31, 2024
Browse files
Merge pull request #209 from ROCm/andriy/merge_from_public
Update develop branch from public repository
parents
f221c2b0
d51701d4
Changes
291
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1536 additions
and
40 deletions
+1536
-40
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp
...ayernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp
+67
-13
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_problem.hpp
...layernorm2d/pipeline/layernorm2d_fwd_pipeline_problem.hpp
+6
-6
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp
...ayernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp
+65
-18
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp
..._tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp
+54
-0
include/ck_tile/ops/permute.hpp
include/ck_tile/ops/permute.hpp
+9
-0
include/ck_tile/ops/permute/kernel/generic_permute_kernel.hpp
...ude/ck_tile/ops/permute/kernel/generic_permute_kernel.hpp
+169
-0
include/ck_tile/ops/permute/pipeline/generic_petmute_problem.hpp
.../ck_tile/ops/permute/pipeline/generic_petmute_problem.hpp
+28
-0
include/ck_tile/ops/reduce.hpp
include/ck_tile/ops/reduce.hpp
+4
-0
include/ck_tile/ops/reduce/block/block_reduce.hpp
include/ck_tile/ops/reduce/block/block_reduce.hpp
+175
-1
include/ck_tile/ops/reduce/block/block_reduce2d.hpp
include/ck_tile/ops/reduce/block/block_reduce2d.hpp
+274
-0
include/ck_tile/ops/reduce/block/block_reduce2d_default_policy.hpp
...k_tile/ops/reduce/block/block_reduce2d_default_policy.hpp
+79
-0
include/ck_tile/ops/reduce/block/block_reduce2d_problem.hpp
include/ck_tile/ops/reduce/block/block_reduce2d_problem.hpp
+18
-0
include/ck_tile/ops/rmsnorm2d.hpp
include/ck_tile/ops/rmsnorm2d.hpp
+13
-0
include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp
...ude/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp
+202
-0
include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_shape.hpp
include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_shape.hpp
+2
-2
include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp
...norm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp
+94
-0
include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp
...ps/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp
+101
-0
include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_problem.hpp
...ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_problem.hpp
+36
-0
include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp
...ps/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp
+131
-0
include/ck_tile/ops/softmax.hpp
include/ck_tile/ops/softmax.hpp
+9
-0
No files found.
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp
View file @
7d50244e
...
...
@@ -5,6 +5,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp"
#include <string>
#include <type_traits>
...
...
@@ -24,14 +25,19 @@ struct Layernorm2dFwdPipelineOnePass
using
MeanDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
MeanDataType
>
;
using
InvStdDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
InvStdDataType
>
;
using
XResidualDataType
=
XDataType
;
using
YResidualDataType
=
XDataType
;
static
constexpr
bool
kHasGamma
=
!
std
::
is_same_v
<
GammaDataType
,
ck_tile
::
null_type
>
;
static
constexpr
bool
kHasBeta
=
!
std
::
is_same_v
<
BetaDataType
,
ck_tile
::
null_type
>
;
static
constexpr
bool
kSaveMean
=
Problem
::
kSaveMeanInvStd
;
static
constexpr
bool
kSaveInvStd
=
Problem
::
kSaveMeanInvStd
;
static
constexpr
bool
kSaveMean
=
Problem
::
Traits
::
kSaveMeanInvStd
;
static
constexpr
bool
kSaveInvStd
=
Problem
::
Traits
::
kSaveMeanInvStd
;
static
constexpr
bool
kNeedCrossWarpSync
=
Problem
::
kNeedCrossWarpSync
;
static
constexpr
bool
kPadM
=
false
;
// TODO - BlockLayernorm2dFwdProblem::kPadM
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
bool
kPadN
=
Problem
::
Traits
::
kPadN
;
static
constexpr
auto
kFusedAdd
=
Problem
::
Traits
::
kFusedAdd
;
static
constexpr
auto
kFusedQuant
=
Problem
::
Traits
::
kFusedQuant
;
static
constexpr
const
char
*
name
=
[]()
{
if
constexpr
(
kNeedCrossWarpSync
)
...
...
@@ -46,20 +52,30 @@ struct Layernorm2dFwdPipelineOnePass
}
template
<
typename
XWindow
,
typename
XResidualWindow
,
typename
GammaWindow
,
typename
BetaWindow
,
typename
YWindow
,
typename
YResidualWindow
,
typename
MeanWindow
,
typename
InvStdWindow
>
typename
InvStdWindow
,
typename
XScaleWindow
,
typename
YScaleWindow
,
typename
Epilogue
>
CK_TILE_DEVICE
auto
operator
()(
const
XWindow
&
x_window_
,
const
XResidualWindow
&
x_residual_window_
,
const
GammaWindow
&
gamma_window_
,
const
BetaWindow
&
beta_window_
,
YWindow
&
y_window
,
YWindow
&
y_window_
,
const
YResidualWindow
&
y_residual_window_
,
MeanWindow
&
mean_window
,
InvStdWindow
&
inv_std_window
,
const
XScaleWindow
&
x_scale_window_
,
YScaleWindow
&
y_scale_window
,
ComputeDataType
epsilon
,
ck_tile
::
index_t
row_size
,
void
*
smem
)
const
void
*
smem
,
Epilogue
)
const
{
const
auto
x_window
=
make_tile_window
(
x_window_
,
Policy
::
template
MakeXBlockTileDistribution
<
Problem
>());
...
...
@@ -67,8 +83,17 @@ struct Layernorm2dFwdPipelineOnePass
gamma_window_
,
Policy
::
template
MakeGammaBetaBlockTileDistribution
<
Problem
>());
const
auto
beta_window
=
make_tile_window
(
beta_window_
,
Policy
::
template
MakeGammaBetaBlockTileDistribution
<
Problem
>());
const
auto
x_residual_window
=
make_tile_window
(
x_residual_window_
,
Policy
::
template
MakeXBlockTileDistribution
<
Problem
>());
auto
y_residual_window
=
make_tile_window
(
y_residual_window_
,
Policy
::
template
MakeXBlockTileDistribution
<
Problem
>());
const
auto
x_scale_window
=
make_tile_window
(
x_scale_window_
,
Policy
::
template
MakeGammaBetaBlockTileDistribution
<
Problem
>());
auto
x
=
load_tile
(
x_window
);
auto
x_resi
=
load_tile
(
x_residual_window
);
auto
x_scale
=
load_tile
(
x_scale_window
);
const
auto
x
=
load_tile
(
x_window
);
int
cur_count
=
0
;
int
max_count
=
block_tile_welford_calculate_max_count
<
typename
Problem
::
BlockShape
>
(
row_size
);
...
...
@@ -81,6 +106,18 @@ struct Layernorm2dFwdPipelineOnePass
const
auto
gamma
=
load_tile
(
gamma_window
);
const
auto
beta
=
load_tile
(
beta_window
);
if
constexpr
(
kFusedAdd
==
Layernorm2dFusedAddEnum
::
PRE_ADD_STORE
||
kFusedAdd
==
Layernorm2dFusedAddEnum
::
PRE_ADD
)
{
sweep_tile
(
x_resi
,
[
&
](
auto
idx
)
{
// compute x = x_resi + x
x
(
idx
)
=
type_convert
<
YResidualDataType
>
(
x_resi
(
idx
))
+
type_convert
<
YResidualDataType
>
(
x
(
idx
));
});
if
constexpr
(
kFusedAdd
==
Layernorm2dFusedAddEnum
::
PRE_ADD_STORE
)
store_tile
(
y_residual_window
,
x
);
}
// compute welford each-thread->cross-lane->cross-warp
auto
[
mean
,
var
]
=
block_welford
(
x
,
cur_count
,
max_count
);
block_welford_sync
(
mean
,
var
,
cur_count
);
...
...
@@ -90,7 +127,7 @@ struct Layernorm2dFwdPipelineOnePass
// compute inv-std
auto
inv_std
=
tile_elementwise_in
(
[
&
](
const
auto
&
v_
)
{
return
type_convert
<
ComputeDataType
>
(
1.0
f
)
/
(
sqrt
(
v_
)
+
epsilon
);
return
type_convert
<
ComputeDataType
>
(
1.0
f
)
/
(
sqrt
(
v_
+
epsilon
)
)
;
},
var
);
...
...
@@ -100,8 +137,8 @@ struct Layernorm2dFwdPipelineOnePass
store_tile
(
inv_std_window
,
cast_tile
<
InvStdDataType
>
(
inv_std
));
// layernorm computation
auto
y
=
make_static_distributed_tensor
<
Y
DataType
>
(
x
.
get_tile_distribution
());
sweep_tile
(
y
,
[
&
,
mean_
=
mean
](
auto
idx
)
{
auto
ln
=
make_static_distributed_tensor
<
Compute
DataType
>
(
x
.
get_tile_distribution
());
sweep_tile
(
ln
,
[
&
,
mean_
=
mean
](
auto
idx
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx
[
number
<
0
>
{}]);
constexpr
auto
j_idx
=
make_tuple
(
idx
[
number
<
1
>
{}]);
...
...
@@ -109,11 +146,28 @@ struct Layernorm2dFwdPipelineOnePass
const
auto
beta_
=
type_convert
<
ComputeDataType
>
(
beta
[
j_idx
]);
const
auto
x_
=
type_convert
<
ComputeDataType
>
(
x
[
idx
]);
auto
y_
=
(
x_
-
mean_
[
i_idx
])
*
inv_std
[
i_idx
]
*
gamma_
+
beta_
;
auto
ln_
=
(
x_
-
mean_
[
i_idx
])
*
inv_std
[
i_idx
]
*
gamma_
+
beta_
;
y
(
idx
)
=
type_convert
<
YDataType
>
(
y_
)
;
ln
(
idx
)
=
ln_
;
});
store_tile
(
y_window
,
y
);
if
constexpr
(
kFusedQuant
==
Layernorm2dFusedQuantEnum
::
SMOOTH_DYNAMIC_QUANT
)
{
// smooth-quant pre-scale, then run rowwise-quant
sweep_tile
(
ln
,
[
&
](
auto
idx
)
{
constexpr
auto
j_idx
=
make_tuple
(
idx
[
number
<
1
>
{}]);
const
auto
xs_
=
type_convert
<
ComputeDataType
>
(
x_scale
[
j_idx
]);
ln
(
idx
)
=
ln
(
idx
)
*
xs_
;
});
}
if
constexpr
(
kFusedQuant
==
Layernorm2dFusedQuantEnum
::
DYNAMIC_QUANT
||
kFusedQuant
==
Layernorm2dFusedQuantEnum
::
SMOOTH_DYNAMIC_QUANT
)
{
Epilogue
{}(
y_window_
,
y_scale_window
,
ln
,
smem
);
}
else
Epilogue
{}(
y_window_
,
ln
);
}
};
}
// namespace ck_tile
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_problem.hpp
View file @
7d50244e
...
...
@@ -14,10 +14,10 @@ template <typename XDataType_,
typename
YDataType_
,
typename
MeanDataType_
,
typename
InvStdDataType_
,
typename
XScaleDataType_
,
typename
YScaleDataType_
,
typename
BlockShape_
,
bool
kPadN_
,
bool
kSaveMeanInvStd_
,
bool
kTwoPass_
>
typename
Traits_
>
struct
Layernorm2dFwdPipelineProblem
{
using
XDataType
=
remove_cvref_t
<
XDataType_
>
;
...
...
@@ -27,14 +27,14 @@ struct Layernorm2dFwdPipelineProblem
using
YDataType
=
remove_cvref_t
<
YDataType_
>
;
using
MeanDataType
=
remove_cvref_t
<
MeanDataType_
>
;
using
InvStdDataType
=
remove_cvref_t
<
InvStdDataType_
>
;
using
XScaleDataType
=
remove_cvref_t
<
XScaleDataType_
>
;
using
YScaleDataType
=
remove_cvref_t
<
YScaleDataType_
>
;
using
BlockShape
=
remove_cvref_t
<
BlockShape_
>
;
static
constexpr
bool
kNeedCrossLaneSync
=
BlockShape
::
ThreadPerWarp_N
>
1
;
static
constexpr
bool
kNeedCrossWarpSync
=
BlockShape
::
WarpPerBlock_N
>
1
;
static
constexpr
bool
kPadN
=
kPadN_
;
static
constexpr
bool
kSaveMeanInvStd
=
kSaveMeanInvStd_
;
static
constexpr
bool
kTwoPass
=
kTwoPass_
;
using
Traits
=
remove_cvref_t
<
Traits_
>
;
};
}
// namespace ck_tile
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp
View file @
7d50244e
...
...
@@ -24,20 +24,25 @@ struct Layernorm2dFwdPipelineTwoPass
using
MeanDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
MeanDataType
>
;
using
InvStdDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
InvStdDataType
>
;
using
XResidualDataType
=
XDataType
;
using
YResidualDataType
=
XDataType
;
static
constexpr
bool
kHasGamma
=
!
std
::
is_same_v
<
GammaDataType
,
ck_tile
::
null_type
>
;
static
constexpr
bool
kHasBeta
=
!
std
::
is_same_v
<
BetaDataType
,
ck_tile
::
null_type
>
;
static
constexpr
bool
kSaveMean
=
Problem
::
kSaveMeanInvStd
;
static
constexpr
bool
kSaveInvStd
=
Problem
::
kSaveMeanInvStd
;
static
constexpr
bool
kSaveMean
=
Problem
::
Traits
::
kSaveMeanInvStd
;
static
constexpr
bool
kSaveInvStd
=
Problem
::
Traits
::
kSaveMeanInvStd
;
static
constexpr
bool
kNeedCrossWarpSync
=
Problem
::
kNeedCrossWarpSync
;
static
constexpr
bool
kPadM
=
false
;
// TODO - BlockLayernorm2dFwdProblem::kPadM
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
bool
kPadN
=
Problem
::
Traits
::
kPadN
;
static
constexpr
auto
kFusedAdd
=
Problem
::
Traits
::
kFusedAdd
;
static
constexpr
auto
kFusedQuant
=
Problem
::
Traits
::
kFusedQuant
;
static
constexpr
const
char
*
name
=
[]()
{
if
constexpr
(
kNeedCrossWarpSync
)
return
"bpr"
;
// block per row
return
"bpr
_2p
"
;
// block per row
else
return
"wpr"
;
// warp per row
return
"wpr
_2p
"
;
// warp per row
}();
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
...
...
@@ -46,20 +51,30 @@ struct Layernorm2dFwdPipelineTwoPass
}
template
<
typename
XWindow
,
typename
XResidualWindow
,
typename
GammaWindow
,
typename
BetaWindow
,
typename
YWindow
,
typename
YResidualWindow
,
typename
MeanWindow
,
typename
InvStdWindow
>
typename
InvStdWindow
,
typename
XScaleWindow
,
typename
YScaleWindow
,
typename
Epilogue
>
CK_TILE_DEVICE
auto
operator
()(
const
XWindow
&
x_window_
,
const
XResidualWindow
&
x_residual_window_
,
const
GammaWindow
&
gamma_window_
,
const
BetaWindow
&
beta_window_
,
YWindow
&
y_window
,
const
YResidualWindow
&
y_residual_window_
,
MeanWindow
&
mean_window
,
InvStdWindow
&
inv_std_window
,
const
XScaleWindow
&
/*x_scale_window*/
,
YScaleWindow
&
/*y_scale_window*/
,
ComputeDataType
epsilon
,
ck_tile
::
index_t
row_size
,
void
*
smem
)
const
void
*
smem
,
Epilogue
)
const
{
auto
x_window
=
make_tile_window
(
x_window_
,
Policy
::
template
MakeXBlockTileDistribution
<
Problem
>());
...
...
@@ -67,6 +82,10 @@ struct Layernorm2dFwdPipelineTwoPass
gamma_window_
,
Policy
::
template
MakeGammaBetaBlockTileDistribution
<
Problem
>());
auto
beta_window
=
make_tile_window
(
beta_window_
,
Policy
::
template
MakeGammaBetaBlockTileDistribution
<
Problem
>());
auto
x_residual_window
=
make_tile_window
(
x_residual_window_
,
Policy
::
template
MakeXBlockTileDistribution
<
Problem
>());
auto
y_residual_window
=
make_tile_window
(
y_residual_window_
,
Policy
::
template
MakeXBlockTileDistribution
<
Problem
>());
// Problem::BlockShape
static
constexpr
index_t
Block_N
=
Problem
::
BlockShape
::
Block_N
;
...
...
@@ -93,9 +112,26 @@ struct Layernorm2dFwdPipelineTwoPass
for
(
int
iN
=
__builtin_amdgcn_readfirstlane
(
0
);
iN
<
num_n_tile_iteration
;
++
iN
)
{
const
auto
x
=
load_tile
(
x_window
);
block_welford
(
x
,
mean
,
var
,
cur_count
,
max_count
);
auto
x
=
load_tile
(
x_window
);
auto
x_resi
=
load_tile
(
x_residual_window
);
move_tile_window
(
x_window
,
{
0
,
Block_N
});
move_tile_window
(
x_residual_window
,
{
0
,
Block_N
});
if
constexpr
(
kFusedAdd
==
Layernorm2dFusedAddEnum
::
PRE_ADD_STORE
||
kFusedAdd
==
Layernorm2dFusedAddEnum
::
PRE_ADD
)
{
sweep_tile
(
x_resi
,
[
&
](
auto
idx
)
{
// compute x = x_resi + x
x
(
idx
)
=
type_convert
<
YResidualDataType
>
(
x_resi
(
idx
))
+
type_convert
<
YResidualDataType
>
(
x
(
idx
));
});
if
constexpr
(
kFusedAdd
==
Layernorm2dFusedAddEnum
::
PRE_ADD_STORE
)
{
store_tile
(
y_residual_window
,
x
);
move_tile_window
(
y_residual_window
,
{
0
,
Block_N
});
}
}
block_welford
(
x
,
mean
,
var
,
cur_count
,
max_count
);
}
block_welford_sync
(
mean
,
var
,
cur_count
);
...
...
@@ -105,7 +141,7 @@ struct Layernorm2dFwdPipelineTwoPass
// compute inv-std
auto
inv_std
=
tile_elementwise_in
(
[
&
](
const
auto
&
v_
)
{
return
type_convert
<
ComputeDataType
>
(
1.0
f
)
/
(
sqrt
(
v_
)
+
epsilon
);
return
type_convert
<
ComputeDataType
>
(
1.0
f
)
/
(
sqrt
(
v_
+
epsilon
)
)
;
},
var
);
...
...
@@ -118,9 +154,8 @@ struct Layernorm2dFwdPipelineTwoPass
ck_tile
::
index_t
stride_to_right_most_window
=
row_size
%
Block_N
==
0
?
row_size
-
Block_N
:
row_size
-
row_size
%
Block_N
;
// x_window.foo();
// gamma_window.foo();
move_tile_window
(
x_window
,
{
0
,
-
Block_N
});
move_tile_window
(
x_residual_window
,
{
0
,
-
Block_N
});
move_tile_window
(
gamma_window
,
{
stride_to_right_most_window
});
move_tile_window
(
beta_window
,
{
stride_to_right_most_window
});
move_tile_window
(
y_window
,
{
0
,
stride_to_right_most_window
});
...
...
@@ -128,14 +163,24 @@ struct Layernorm2dFwdPipelineTwoPass
// layernorm computation
for
(
int
iN
=
__builtin_amdgcn_readfirstlane
(
0
);
iN
<
num_n_tile_iteration
;
++
iN
)
{
const
auto
x
=
load_tile
(
x_window
);
auto
x
=
load_tile
(
x_window
);
auto
x_resi
=
load_tile
(
x_residual_window
);
if
constexpr
(
kFusedAdd
==
Layernorm2dFusedAddEnum
::
PRE_ADD_STORE
||
kFusedAdd
==
Layernorm2dFusedAddEnum
::
PRE_ADD
)
{
sweep_tile
(
x_resi
,
[
&
](
auto
idx
)
{
// compute x = x_resi + x
x
(
idx
)
=
type_convert
<
YResidualDataType
>
(
x_resi
(
idx
))
+
type_convert
<
YResidualDataType
>
(
x
(
idx
));
});
}
// load gamma/beta (TODO: support no gamma/beta?)
const
auto
gamma
=
load_tile
(
gamma_window
);
const
auto
beta
=
load_tile
(
beta_window
);
auto
y
=
make_static_distributed_tensor
<
Y
DataType
>
(
x
.
get_tile_distribution
());
auto
ln
=
make_static_distributed_tensor
<
Compute
DataType
>
(
x
.
get_tile_distribution
());
sweep_tile
(
y
,
[
&
,
mean_
=
mean
](
auto
idx
)
{
sweep_tile
(
ln
,
[
&
,
mean_
=
mean
](
auto
idx
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx
[
number
<
0
>
{}]);
constexpr
auto
j_idx
=
make_tuple
(
idx
[
number
<
1
>
{}]);
...
...
@@ -143,14 +188,16 @@ struct Layernorm2dFwdPipelineTwoPass
const
auto
beta_
=
type_convert
<
ComputeDataType
>
(
beta
[
j_idx
]);
const
auto
x_
=
type_convert
<
ComputeDataType
>
(
x
[
idx
]);
auto
y_
=
(
x_
-
mean_
[
i_idx
])
*
inv_std
[
i_idx
]
*
gamma_
+
beta_
;
auto
ln_
=
(
x_
-
mean_
[
i_idx
])
*
inv_std
[
i_idx
]
*
gamma_
+
beta_
;
y
(
idx
)
=
type_convert
<
YDataType
>
(
y_
)
;
ln
(
idx
)
=
ln_
;
});
store_tile
(
y_window
,
y
);
static_assert
(
kFusedQuant
!=
Layernorm2dFusedQuantEnum
::
DYNAMIC_QUANT
);
Epilogue
{}(
y_window
,
ln
);
move_tile_window
(
x_window
,
{
0
,
-
Block_N
});
move_tile_window
(
x_residual_window
,
{
0
,
-
Block_N
});
move_tile_window
(
gamma_window
,
{
-
Block_N
});
move_tile_window
(
beta_window
,
{
-
Block_N
});
move_tile_window
(
y_window
,
{
0
,
-
Block_N
});
...
...
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp
0 → 100644
View file @
7d50244e
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/utility/type_traits.hpp"
namespace
ck_tile
{
enum
class
Layernorm2dFusedAddEnum
{
NO_ADD
=
0
,
// fused add before layernorm and store result to global
PRE_ADD_STORE
=
1
,
// fused add before layernorm, but not store result
PRE_ADD
=
2
,
};
// clang-format off
template
<
Layernorm2dFusedAddEnum
>
struct
Layernorm2dFusedAddEnumName
;
template
<
>
struct
Layernorm2dFusedAddEnumName
<
Layernorm2dFusedAddEnum
::
NO_ADD
>
{
static
constexpr
const
char
*
name
=
"no"
;
};
template
<
>
struct
Layernorm2dFusedAddEnumName
<
Layernorm2dFusedAddEnum
::
PRE_ADD_STORE
>
{
static
constexpr
const
char
*
name
=
"pras"
;
};
template
<
>
struct
Layernorm2dFusedAddEnumName
<
Layernorm2dFusedAddEnum
::
PRE_ADD
>
{
static
constexpr
const
char
*
name
=
"pra"
;
};
// clang-format on
enum
class
Layernorm2dFusedQuantEnum
{
NO_SWEEP
=
0
,
SMOOTH_DYNAMIC_QUANT
=
1
,
// smooth oulier + rowwise quant, need input x-scale and store y_scale
DYNAMIC_QUANT
=
2
,
// rowwise quant, store out a y-scale
};
// clang-format off
template
<
Layernorm2dFusedQuantEnum
>
struct
Layernorm2dFusedQuantEnumName
;
template
<
>
struct
Layernorm2dFusedQuantEnumName
<
Layernorm2dFusedQuantEnum
::
NO_SWEEP
>
{
static
constexpr
const
char
*
name
=
"no"
;
};
template
<
>
struct
Layernorm2dFusedQuantEnumName
<
Layernorm2dFusedQuantEnum
::
DYNAMIC_QUANT
>
{
static
constexpr
const
char
*
name
=
"dqt"
;
};
template
<
>
struct
Layernorm2dFusedQuantEnumName
<
Layernorm2dFusedQuantEnum
::
SMOOTH_DYNAMIC_QUANT
>
{
static
constexpr
const
char
*
name
=
"smdqt"
;
};
// clang-format on
template
<
bool
kPadN_
,
bool
kSaveMeanInvStd_
,
bool
kTwoPass_
,
Layernorm2dFusedAddEnum
kFusedAdd_
,
Layernorm2dFusedQuantEnum
kFusedQuant_
>
struct
Layernorm2dFwdTraits
{
static
constexpr
bool
kPadN
=
kPadN_
;
static
constexpr
bool
kSaveMeanInvStd
=
kSaveMeanInvStd_
;
static
constexpr
bool
kTwoPass
=
kTwoPass_
;
static
constexpr
Layernorm2dFusedAddEnum
kFusedAdd
=
kFusedAdd_
;
static
constexpr
Layernorm2dFusedQuantEnum
kFusedQuant
=
kFusedQuant_
;
};
}
// namespace ck_tile
include/ck_tile/ops/permute.hpp
0 → 100644
View file @
7d50244e
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/permute/kernel/generic_permute_kernel.hpp"
#include "ck_tile/ops/permute/pipeline/generic_petmute_problem.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/permute/kernel/generic_permute_kernel.hpp
0 → 100644
View file @
7d50244e
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
// #include "ck_tile/ops/permute/pipeline/generic_petmute_problem.hpp"
namespace
ck_tile
{
/* independent host side argument, no template
*/
struct
GenericPermuteHostArgs
{
static
constexpr
index_t
kMaxRanks
=
8
;
// TODO: hardcoded
const
void
*
p_src
;
void
*
p_dst
;
index_t
rank
;
index_t
shape
[
kMaxRanks
];
// input shape
index_t
perm
[
kMaxRanks
];
// permute index
};
/*
simulate torch.permute:
x_ = x_.view(x.shape[0],
x.shape[1]//16, 16,
x.shape[2]//32, 4, 8)
x_ = x_.permute(0,1,3,4,2,5)
x_ = x_.contiguous()
x_ = x_.view(x.shape[0], x.shape[1], x.shape[2]);//
this kernel is supposed not to be performant(just OK), with functional support up to kMaxRanks
dim of permutation, with a single kernel
*/
template
<
typename
Problem_
>
struct
GenericPermute
{
using
Problem
=
ck_tile
::
remove_cvref_t
<
Problem_
>
;
using
DataType
=
remove_cvref_t
<
typename
Problem
::
DataType
>
;
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
static
constexpr
index_t
kMaxRanks
=
Problem
::
kMaxRanks
;
static
constexpr
bool
KeepLastDim
=
Problem
::
KeepLastDim
;
struct
__attribute__
((
packed
))
Kargs
{
const
void
*
p_src
;
void
*
p_dst
;
// index_t rank;
index_t
num_elements
;
index_t
perm_length
[
kMaxRanks
];
// tensor length after permutation
index_t
perm_stride
[
kMaxRanks
];
// tensor stride after permutation
};
CK_TILE_HOST
static
constexpr
index_t
TotalElements
(
const
GenericPermuteHostArgs
&
h
)
{
index_t
n
=
1
;
for
(
auto
i
=
0
;
i
<
h
.
rank
;
i
++
)
{
n
*=
h
.
shape
[
i
];
}
return
n
;
}
CK_TILE_HOST
static
constexpr
Kargs
MakeKargs
(
const
GenericPermuteHostArgs
&
h
)
{
Kargs
a
;
a
.
p_src
=
h
.
p_src
;
a
.
p_dst
=
h
.
p_dst
;
// assert rank <= kMaxRanks
index_t
i
=
0
;
index_t
perm
[
kMaxRanks
];
index_t
x_shape
[
kMaxRanks
];
index_t
x_stride
[
kMaxRanks
];
// index_t perm_length[kMaxRanks];
for
(;
i
<
h
.
rank
;
i
++
)
{
x_shape
[
i
]
=
h
.
shape
[
i
];
perm
[
i
]
=
h
.
perm
[
i
];
}
for
(;
i
<
kMaxRanks
;
i
++
)
{
x_shape
[
i
]
=
1
;
perm
[
i
]
=
i
;
// will index to len = 1
}
index_t
stride
=
1
;
for
(
index_t
j
=
kMaxRanks
-
1
;
j
>=
0
;
j
--
)
{
x_stride
[
j
]
=
stride
;
stride
*=
x_shape
[
j
];
}
for
(
index_t
j
=
0
;
j
<
kMaxRanks
;
j
++
)
{
a
.
perm_length
[
j
]
=
x_shape
[
perm
[
j
]];
a
.
perm_stride
[
j
]
=
x_stride
[
perm
[
j
]];
}
a
.
num_elements
=
TotalElements
(
h
);
return
a
;
}
CK_TILE_HOST
static
constexpr
auto
GridSize
(
GenericPermuteHostArgs
h
)
{
auto
total
=
TotalElements
(
h
);
auto
grids
=
dim3
((
total
+
BlockSize
()
-
1
)
/
BlockSize
());
// printf("### total:%d, grids:%dx%dx%d\n", total, );
return
grids
;
}
CK_TILE_HOST_DEVICE
static
constexpr
auto
BlockSize
()
{
return
Problem
::
kBlockSize
;
}
CK_TILE_DEVICE
void
operator
()(
Kargs
kargs
)
const
{
index_t
id
=
blockIdx
.
x
*
BlockSize
()
+
threadIdx
.
x
;
if
(
id
>=
kargs
.
num_elements
)
return
;
const
auto
perm_length
=
generate_tuple
([
&
](
auto
I
)
{
return
kargs
.
perm_length
[
I
];
},
number
<
kMaxRanks
>
{});
const
auto
perm_stride
=
generate_tuple
([
&
](
auto
I
)
{
return
kargs
.
perm_stride
[
I
];
},
number
<
kMaxRanks
>
{});
const
DataType
*
p_src
=
reinterpret_cast
<
const
DataType
*>
(
kargs
.
p_src
);
DataType
*
p_dst
=
reinterpret_cast
<
DataType
*>
(
kargs
.
p_dst
);
const
auto
src_view_0
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
p_src
,
perm_length
,
perm_stride
,
number
<
1
>
{},
number
<
1
>
{});
const
auto
src_view
=
transform_tensor_view
(
src_view_0
,
make_tuple
(
make_merge_transform
(
perm_length
)),
make_tuple
(
typename
arithmetic_sequence_gen
<
0
,
kMaxRanks
,
1
>::
type
{}),
make_tuple
(
sequence
<
0
>
{}));
auto
dst_view_0
=
make_naive_tensor_view_packed
<
address_space_enum
::
global
>
(
p_dst
,
perm_length
,
number
<
1
>
{});
auto
dst_view
=
transform_tensor_view
(
dst_view_0
,
make_tuple
(
make_merge_transform
(
perm_length
)),
make_tuple
(
typename
arithmetic_sequence_gen
<
0
,
kMaxRanks
,
1
>::
type
{}),
make_tuple
(
sequence
<
0
>
{}));
// TODO: hard code to vector 1
using
vector_t
=
thread_buffer
<
DataType
,
1
>
;
const
auto
src_coord
=
make_tensor_coordinate
(
src_view
.
get_tensor_descriptor
(),
array
<
index_t
,
1
>
{
id
});
const
auto
dst_coord
=
make_tensor_coordinate
(
dst_view
.
get_tensor_descriptor
(),
array
<
index_t
,
1
>
{
id
});
// printf("src id:%d, os:%d\n", id, src_coord.get_offset());
// printf("dst id:%d, os:%d\n", id, dst_coord.get_offset());
const
vector_t
x
=
src_view
.
template
get_vectorized_elements
<
vector_t
>(
src_coord
,
0
);
dst_view
.
template
set_vectorized_elements
<
vector_t
>(
dst_coord
,
0
,
x
);
}
};
}
// namespace ck_tile
include/ck_tile/ops/permute/pipeline/generic_petmute_problem.hpp
0 → 100644
View file @
7d50244e
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/utility/type_traits.hpp"
namespace
ck_tile
{
template
<
typename
DataType_
,
index_t
kBlockSize_
=
256
,
index_t
kMaxRanks_
=
8
,
bool
KeepLastDim_
=
false
>
struct
GenericPermuteProblem
{
using
DataType
=
remove_cvref_t
<
DataType_
>
;
static
constexpr
index_t
kBlockSize
=
kBlockSize_
;
static
constexpr
index_t
kMaxRanks
=
kMaxRanks_
;
/* KeepLastDim:
* if last dim keep the same? this can help enable vector load
* permute(0, 2, 4, 1, 3, 5) -> true
* permute(0, 3, 2, 1) -> false
*/
static
constexpr
bool
KeepLastDim
=
KeepLastDim_
;
// TODO: not used(?)
};
}
// namespace ck_tile
include/ck_tile/ops/reduce.hpp
View file @
7d50244e
...
...
@@ -4,4 +4,8 @@
#pragma once
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
#include "ck_tile/ops/reduce/block/block_reduce2d.hpp"
#include "ck_tile/ops/reduce/block/block_reduce2d_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce2d_problem.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/reduce/block/block_reduce.hpp
View file @
7d50244e
...
...
@@ -4,9 +4,15 @@
#pragma once
#include "ck_tile/core.hpp"
#include <tuple>
// This file is not support cross warp reduce
namespace
ck_tile
{
/*
* TODO: block_tile_reduce_sync() currently has a limitation
* Y dim must have at least one dim not been reduced
*/
// synchronize reduce result (cross lane reduction and broadcast on replicated dimension)
template
<
typename
AccDistributedTensor_
,
typename
ReduceFunc
,
bool
WithBroadcast
=
true
>
CK_TILE_DEVICE
void
block_tile_reduce_sync
(
AccDistributedTensor_
&
acc_tensor
,
...
...
@@ -22,7 +28,7 @@ CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor,
constexpr
index_t
idim_p_lane
=
NDimP
-
1
;
const
auto
ps_idx
=
make_array
<
index_t
>
(
get_block_id
(),
get_lane_id
());
const
auto
ps_idx
=
detail
::
get_partition_index
(
acc_tensor
.
get_tile_distribution
());
const
auto
rs_idx
=
acc_tensor
.
get_tile_distribution
().
calculate_rs_index_from_ps_index
(
ps_idx
);
constexpr
index_t
thread_buf_size
=
AccDistributedTensor_
::
get_thread_buffer_size
();
...
...
@@ -104,6 +110,65 @@ CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor,
});
}
/*
* this version is faster, using xor to do reduce, no need broadcast anymore
* TODO: the limitation is to-be-reduced P dim can only mapping to one R dim?
*/
template
<
typename
AccDistributedTensor_
,
typename
ReduceFunc
>
CK_TILE_DEVICE
void
block_tile_reduce_xor_sync
(
AccDistributedTensor_
&
acc_tensor
,
const
ReduceFunc
&
reduce_func
)
{
using
Dstr
=
typename
AccDistributedTensor_
::
StaticTileDistribution
;
using
DstrEncode
=
typename
Dstr
::
DstrEncode
;
using
DstrEncodeDetail
=
typename
DstrEncode
::
detail
;
constexpr
index_t
NDimP
=
Dstr
::
get_num_of_dimension_p
();
constexpr
index_t
NDimR
=
Dstr
::
get_num_of_dimension_r
();
constexpr
index_t
idim_p_lane
=
NDimP
-
1
;
constexpr
index_t
thread_buf_size
=
AccDistributedTensor_
::
get_thread_buffer_size
();
// loop over thread data
static_for
<
0
,
thread_buf_size
,
1
>
{}([
&
](
auto
i
)
{
auto
v_local
=
acc_tensor
.
get_thread_buffer
()[
i
];
// cross-lane reduce for replication
// only reduce on R dimension correspond to lane
// (lane id maps to this R dimension)
static_for
<
0
,
NDimR
,
1
>
{}([
&
](
auto
idim_r
)
{
// FIXME: nasty to use does_p_own_r_
if
constexpr
(
DstrEncodeDetail
::
does_p_own_r_
[
idim_p_lane
][
idim_r
])
{
constexpr
index_t
r_length
=
DstrEncode
::
rs_lengths_
[
idim_r
];
constexpr
index_t
lid_over_rid_derivative
=
DstrEncodeDetail
::
ps_over_rs_derivative_
[
idim_p_lane
][
idim_r
];
static_assert
(
is_power_of_two_integer
(
r_length
),
"wrong! only support power of 2 reduction"
);
constexpr
index_t
nstage
=
integer_log2_floor
(
r_length
);
// reduction sweep forward
static_for
<
0
,
nstage
,
1
>
{}([
&
](
auto
istage
)
{
// xor
index_t
src_lane
=
__lane_id
()
^
(
number
<
lid_over_rid_derivative
<<
istage
.
value
>
{}.
value
);
// pull data from remote lane
const
auto
v_remote
=
warp_shuffle
(
v_local
,
src_lane
);
// reduce
v_local
=
reduce_func
(
v_local
,
v_remote
);
});
}
});
acc_tensor
.
get_thread_buffer
()(
i
)
=
v_local
;
});
}
// FIXME: this is for 2D to 1D reduce only, need to support n-D
template
<
typename
AccDistributedTensor_
,
typename
InDistributedTensor_
,
...
...
@@ -175,6 +240,10 @@ CK_TILE_DEVICE void block_tile_reduce(AccDistributedTensor_& acc_tensor,
#endif
}
/*
* TODO: block_tile_reduce() currently has a limitation
* Y dim must have at least one dim not been reduced
*/
template
<
typename
AccDataType_
,
typename
InDistributedTensor_
,
index_t
...
InReduceDims
,
...
...
@@ -208,4 +277,109 @@ CK_TILE_DEVICE auto block_tile_reduce(const InDistributedTensor_& in_tensor,
return
acc_tensor
;
}
// this version only support 2D->1D reduce (reduce-dim=seq<0, 1>)
// this version only support in/acc/out datatypes are the same
// this version will call thread/warp+sync in one function call
//
template
<
typename
InDistributedTensor_
>
struct
BlockReduce2D
{
using
InDistributedTensor
=
remove_cvref_t
<
InDistributedTensor_
>
;
using
InDataType
=
typename
InDistributedTensor
::
DataType
;
CK_TILE_HOST_DEVICE
BlockReduce2D
(
const
InDistributedTensor
&
t_
,
const
InDataType
&
reduce_init_
)
:
t
(
t_
),
reduce_init
(
reduce_init_
)
{
}
CK_TILE_HOST_DEVICE
constexpr
auto
MakeDstBlockTile
()
const
{
using
ReduceDim
=
sequence
<
1
>
;
// hard coded
constexpr
auto
acc_dstr
=
make_static_tile_distribution
(
ck_tile
::
detail
::
make_reduce_tile_distribution_encoding
(
InDistributedTensor
::
get_tile_distribution
()
.
get_static_tile_distribution_encoding
(),
ReduceDim
{}));
auto
dst_
=
make_static_distributed_tensor
<
InDataType
>
(
acc_dstr
);
// init acc_tensor
tile_elementwise_inout
([
&
](
auto
&
x_
)
{
x_
=
type_convert
<
InDataType
>
(
reduce_init
);
},
dst_
);
return
dst_
;
}
// return number of pixels each lane need to reduce
CK_TILE_HOST_DEVICE
constexpr
auto
get_reduce_length_y
()
const
{
constexpr
auto
spans
=
InDistributedTensor
::
get_distributed_spans
();
}
// Here ReducePacksPerXDim is not the same meaning as that in static_uford/sweep_tile_uspan
// this is number of packs along the X-dim. We need to compute the Unpacks along the Y dim
// internally
// For simplicity, we just support along the row dimension, ReducePacksPerXDim is always 2
// element , and the first element is always ignored For simplicity, will always try from
// right-to-left to find alone which Y dim to split
template
<
typename
ReduceFunc
,
typename
ReduceSyncFunc
,
typename
ReducePacksPerXDim
=
uniform_sequence_gen_t
<
2
,
1
>
>
CK_TILE_HOST_DEVICE
auto
operator
()(
const
ReduceFunc
&
reduce_func
,
const
ReduceSyncFunc
&
reduce_sync_func
,
ReducePacksPerXDim
=
{})
const
{
constexpr
auto
spans
=
InDistributedTensor
::
get_distributed_spans
();
constexpr
auto
row_y_unpacks
=
[
&
]()
{
constexpr
auto
row_y_lengths
=
typename
decltype
(
spans
[
number
<
1
>
{}])
::
Impl
{};
constexpr
auto
row_y_size
=
reduce_on_sequence
(
row_y_lengths
,
multiplies
{},
number
<
1
>
{});
constexpr
auto
row_y_packs
=
ReducePacksPerXDim
{}.
at
(
number
<
1
>
{});
static_assert
(
row_y_size
%
row_y_packs
==
0
);
constexpr
auto
row_y_slice_size
=
row_y_size
/
row_y_packs
;
constexpr
auto
slice_info
=
slice_sequence
(
row_y_lengths
,
number
<
row_y_slice_size
>
{});
constexpr
auto
unpacks
=
slice_info
[
number
<
1
>
{}];
return
unpacks
;
}();
auto
acc_tensor
=
MakeDstBlockTile
();
// in-thread reduction
// FIXME: hard coded to be 2D to 1D reduction
sweep_tile_span
(
spans
[
number
<
0
>
{}],
[
&
](
auto
dstr_idx_i0
)
{
constexpr
auto
acc_dstr_idx
=
make_tuple
(
dstr_idx_i0
);
auto
acc
=
acc_tensor
[
acc_dstr_idx
];
sweep_tile_uspan
(
spans
[
number
<
1
>
{}],
[
&
](
auto
...
dstr_idx_i1
)
{
acc
=
reduce_func
(
acc
,
t
[
make_tuple
(
dstr_idx_i0
,
dstr_idx_i1
)]...);
},
row_y_unpacks
);
acc_tensor
(
acc_dstr_idx
)
=
acc
;
});
// TODO: always use xor to do cross-lane reduce
block_tile_reduce_xor_sync
(
acc_tensor
,
reduce_sync_func
);
return
acc_tensor
;
}
template
<
typename
ReduceFunc
>
CK_TILE_HOST_DEVICE
auto
operator
()(
const
ReduceFunc
&
reduce_func
)
const
{
return
operator
()(
reduce_func
,
reduce_func
);
}
InDistributedTensor
t
;
InDataType
reduce_init
;
};
// deduction guide
template
<
typename
T
>
CK_TILE_HOST_DEVICE_EXTERN
BlockReduce2D
(
const
T
&
,
const
typename
T
::
DataType
&
)
->
BlockReduce2D
<
T
>
;
}
// namespace ck_tile
include/ck_tile/ops/reduce/block/block_reduce2d.hpp
0 → 100644
View file @
7d50244e
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
template
<
typename
Problem_
,
typename
Policy_
=
void
>
struct
BlockReduce2d
{
// in-thread reduction
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
XDataType
=
typename
Problem
::
XDataType
;
using
ComputeDataType
=
typename
Problem
::
ComputeDataType
;
CK_TILE_DEVICE
constexpr
BlockReduce2d
()
{}
template
<
typename
XDistributedTensor_
,
typename
YDistributedTensor_
,
typename
ReduceFunc
,
typename
ReducePacksPerXDim
=
uniform_sequence_gen_t
<
2
,
1
>
>
CK_TILE_DEVICE
void
operator
()(
const
XDistributedTensor_
&
x_tensor
,
YDistributedTensor_
&
y_tensor
,
const
ReduceFunc
&
reduce_func
,
ReducePacksPerXDim
=
{})
{
sweep_tile
<
XDistributedTensor_
>
(
[
&
](
auto
...
idx_
)
{
constexpr
auto
idx_0
=
make_tuple
(
make_tuple
(
idx_
[
number
<
0
>
{}]...)[
number
<
0
>
{}]);
y_tensor
(
idx_0
)
=
reduce_func
(
y_tensor
(
idx_0
),
x_tensor
[
idx_
]...);
},
ReducePacksPerXDim
{});
#if 0
constexpr auto I0 = number<0>{};
constexpr auto I1 = number<1>{};
constexpr auto spans = XDistributedTensor_::get_distributed_spans();
// FIXME: hard coded to reduce 2nd axis
sweep_tile_span(spans[I0], [&](auto dstr_idx_i0) {
constexpr auto y_dstr_idx = make_tuple(dstr_idx_i0);
auto y = y_tensor[y_dstr_idx];
sweep_tile_span(spans[I1], [&](auto dstr_idx_i1) {
constexpr auto in_dstr_idx = make_tuple(dstr_idx_i0, dstr_idx_i1);
const auto x = ck_tile::type_convert<ComputeDataType>(x_tensor[in_dstr_idx]);
y = reduce_func(y, x);
});
y_tensor(y_dstr_idx) = y;
});
#endif
}
template
<
typename
XDistributedTensor_
>
CK_TILE_DEVICE
static
auto
MakeYBlockTile
()
{
static_assert
(
std
::
is_same_v
<
XDataType
,
typename
XDistributedTensor_
::
DataType
>
,
"wrong!"
);
// FIXME: hard coded to reduce 2nd axis
constexpr
auto
reduce_dims
=
sequence
<
1
>
{};
constexpr
auto
dstr
=
make_static_tile_distribution
(
detail
::
make_reduce_tile_distribution_encoding
(
XDistributedTensor_
::
get_tile_distribution
()
.
get_static_tile_distribution_encoding
(),
reduce_dims
));
auto
tensor
=
make_static_distributed_tensor
<
ComputeDataType
>
(
dstr
);
return
tensor
;
}
template
<
typename
XDistributedTensor_
,
typename
ReduceFunc
,
typename
ReducePacksPerXDim
=
uniform_sequence_gen_t
<
2
,
1
>
>
CK_TILE_DEVICE
auto
operator
()(
const
XDistributedTensor_
&
x_tensor
,
const
ComputeDataType
&
reduce_init
,
const
ReduceFunc
&
reduce_func
,
ReducePacksPerXDim
=
{})
{
auto
y_tensor
=
MakeYBlockTile
<
XDistributedTensor_
>
();
set_tile
(
y_tensor
,
reduce_init
);
(
*
this
)(
x_tensor
,
y_tensor
,
reduce_func
,
ReducePacksPerXDim
{});
return
y_tensor
;
}
};
template
<
typename
Problem_
,
typename
Policy_
=
void
>
struct
BlockReduce2dSync
{
using
Problem
=
remove_cvref_t
<
Problem_
>
;
template
<
typename
YDistributedTensor_
,
typename
ReduceFunc
>
CK_TILE_DEVICE
void
operator
()(
YDistributedTensor_
&
y_tensor
,
const
ReduceFunc
&
reduce_func
)
{
using
Dstr
=
typename
YDistributedTensor_
::
StaticTileDistribution
;
using
DstrEncode
=
typename
Dstr
::
DstrEncode
;
using
DstrEncodeDetail
=
typename
DstrEncode
::
detail
;
constexpr
index_t
NDimP
=
Dstr
::
get_num_of_dimension_p
();
constexpr
index_t
NDimR
=
Dstr
::
get_num_of_dimension_r
();
constexpr
index_t
idim_p_lane
=
NDimP
-
1
;
// const auto ps_idx = make_array<index_t>(get_warp_id(), get_lane_id());
// const auto rs_idx =
// y_tensor.get_tile_distribution().calculate_rs_index_from_ps_index(ps_idx);
constexpr
index_t
thread_buf_size
=
YDistributedTensor_
::
get_thread_buffer_size
();
// loop over thread data
static_for
<
0
,
thread_buf_size
,
1
>
{}([
&
](
auto
i
)
{
auto
v_local
=
y_tensor
.
get_thread_buffer
()[
i
];
// cross-lane reduce for replication
// only reduce on R dimension correspond to lane
// (lane id maps to this R dimension)
static_for
<
0
,
NDimR
,
1
>
{}([
&
](
auto
idim_r
)
{
// FIXME: nasty to use does_p_own_r_
if
constexpr
(
DstrEncodeDetail
::
does_p_own_r_
[
idim_p_lane
][
idim_r
])
{
constexpr
index_t
r_length
=
DstrEncode
::
rs_lengths_
[
idim_r
];
constexpr
index_t
lid_over_rid_derivative
=
DstrEncodeDetail
::
ps_over_rs_derivative_
[
idim_p_lane
][
idim_r
];
static_assert
(
is_power_of_two_integer
(
r_length
),
"wrong! only support power of 2 reduction"
);
constexpr
index_t
nstage
=
integer_log2_floor
(
r_length
);
// reduction sweep forward
static_for
<
0
,
nstage
,
1
>
{}([
&
](
auto
istage
)
{
// xor
index_t
src_lane
=
(
__lane_id
())
^
(
number
<
lid_over_rid_derivative
<<
istage
.
value
>
{}.
value
);
// pull data from remote lane
const
auto
v_remote
=
warp_shuffle
(
v_local
,
src_lane
);
// reduce
v_local
=
reduce_func
(
v_local
,
v_remote
);
});
}
});
// TODO - Do we need to broadcast to other lane?
y_tensor
.
get_thread_buffer
()(
i
)
=
v_local
;
});
}
};
template
<
typename
Problem_
,
typename
Policy_
=
void
>
struct
BlockReduce2dCrossWarpSync
{
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
BlockShape
=
typename
Problem
::
BlockShape
;
template
<
typename
YDistributedTensor_
>
CK_TILE_DEVICE
static
constexpr
index_t
GetReduceWarps
()
{
constexpr
index_t
num_reduce_warps
=
[
&
]()
{
using
Dstr
=
typename
YDistributedTensor_
::
StaticTileDistribution
;
using
DstrEncode
=
typename
Dstr
::
DstrEncode
;
using
DstrEncodeDetail
=
typename
DstrEncode
::
detail
;
constexpr
index_t
NDimR
=
Dstr
::
get_num_of_dimension_r
();
constexpr
index_t
idim_p_warp
=
0
;
index_t
len_
=
1
;
static_for
<
0
,
NDimR
,
1
>
{}([
&
](
auto
idim_r
)
{
if
constexpr
(
DstrEncodeDetail
::
does_p_own_r_
[
idim_p_warp
][
idim_r
])
{
constexpr
index_t
r_length
=
DstrEncode
::
rs_lengths_
[
idim_r
];
len_
*=
r_length
;
}
});
return
len_
;
}();
return
num_reduce_warps
;
}
// return in byte
template
<
typename
YDistributedTensor_
>
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
using
DataType
=
typename
YDistributedTensor_
::
DataType
;
// constexpr auto num_reduce_warps = GetReduceWarps<YDistributedTensor_>();
constexpr
index_t
thread_buf_size
=
YDistributedTensor_
::
get_thread_buffer_size
();
// we need to store all data from every wave into smem
// e.g. 2x2 reduce along N
// -------------> reduce N
// | w0 | w1 | ___> | w01 |
// | w2 | w3 | | w23 |
//
// -> store data from every wave into LDS
//
//
// -------------> reduce N
// | w0 | w1 | w2 | w3 | -----> | w0123 |
//
// -> also store data from every wave into LDS
constexpr
index_t
num_warps
=
BlockShape
::
BlockSize
/
warpSize
;
return
num_warps
*
thread_buf_size
*
sizeof
(
DataType
);
}
template
<
typename
YDistributedTensor_
,
typename
ReduceFunc
>
CK_TILE_DEVICE
void
operator
()(
YDistributedTensor_
&
y_tensor
,
void
*
smem
,
const
ReduceFunc
&
reduce_func
)
{
using
DataType
=
typename
YDistributedTensor_
::
DataType
;
constexpr
index_t
thread_buf_size
=
YDistributedTensor_
::
get_thread_buffer_size
();
DataType
*
smem_ptr
=
reinterpret_cast
<
DataType
*>
(
smem
);
const
index_t
lane_id
=
get_lane_id
();
const
index_t
warp_id
=
get_warp_id
();
constexpr
auto
num_reduce_warps
=
GetReduceWarps
<
YDistributedTensor_
>
();
constexpr
index_t
num_warps
=
BlockShape
::
BlockSize
/
warpSize
;
const
index_t
smem_offset
=
warp_id
;
// skip if nonthing to do
if
constexpr
(
num_reduce_warps
==
1
)
return
;
// store into smem only for lane-0 within one warp
if
(
lane_id
==
0
)
{
static_for
<
0
,
thread_buf_size
,
1
>
{}([
&
](
auto
i
)
{
smem_ptr
[
smem_offset
+
i
*
num_warps
]
=
y_tensor
.
get_thread_buffer
()[
i
];
});
}
block_sync_lds
();
// load from smem. here we let everythread to do compute :)
index_t
local_warp_id
=
warp_id
/
num_reduce_warps
;
index_t
local_smem_os
=
local_warp_id
*
num_reduce_warps
;
DataType
all_scratch
[
thread_buf_size
*
num_reduce_warps
];
static_for
<
0
,
thread_buf_size
,
1
>
{}([
&
](
auto
i_0
)
{
static_for
<
0
,
num_reduce_warps
,
1
>
{}([
&
](
auto
i_1
)
{
all_scratch
[
i_0
*
num_reduce_warps
+
i_1
]
=
smem_ptr
[
i_0
*
num_warps
+
local_smem_os
+
i_1
];
});
});
block_sync_lds
();
// TODO: we don't need sync here
static_for
<
0
,
thread_buf_size
,
1
>
{}([
&
](
auto
i_0
)
{
// TODO: use descriptor for this
auto
v_local
=
all_scratch
[
i_0
*
num_reduce_warps
];
// further reduce mean/var
static_for
<
0
,
num_reduce_warps
-
1
,
1
>
{}([
&
](
auto
i_1_n1
)
{
constexpr
auto
i_1
=
number
<
i_1_n1
+
1
>
{};
const
DataType
v_remote
=
all_scratch
[
i_0
*
num_reduce_warps
+
i_1
];
// reduce
v_local
=
reduce_func
(
v_local
,
v_remote
);
});
y_tensor
.
get_thread_buffer
()(
i_0
)
=
v_local
;
});
}
};
}
// namespace ck_tile
include/ck_tile/ops/reduce/block/block_reduce2d_default_policy.hpp
0 → 100644
View file @
7d50244e
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/reduce/block/block_reduce2d_problem.hpp"
#include "ck_tile/ops/reduce/block/block_reduce2d.hpp"
namespace
ck_tile
{
struct
BlockReduce2dDefaultPolicy
{
template
<
typename
Problem
>
CK_TILE_DEVICE
static
constexpr
auto
MakeXBlockTileDistribution
()
{
using
S
=
typename
Problem
::
BlockShape
;
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
S
::
Repeat_M
,
S
::
WarpPerBlock_M
,
S
::
ThreadPerWarp_M
,
S
::
Vector_M
>
,
sequence
<
S
::
Repeat_N
,
S
::
WarpPerBlock_N
,
S
::
ThreadPerWarp_N
,
S
::
Vector_N
>>
,
tuple
<
sequence
<
1
,
2
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
,
1
>
,
sequence
<
2
,
2
>>
,
sequence
<
1
,
1
,
2
,
2
>
,
sequence
<
0
,
3
,
0
,
3
>>
{});
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockReduce2d
()
{
using
P_
=
BlockReduce2dProblem
<
typename
Problem
::
XDataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
BlockShape
>
;
return
BlockReduce2d
<
P_
>
{};
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockReduce2dSync
()
{
using
P_
=
BlockReduce2dProblem
<
typename
Problem
::
XDataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
BlockShape
>
;
return
BlockReduce2dSync
<
P_
>
{};
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockReduce2dCrossWarpSync
()
{
using
P_
=
BlockReduce2dProblem
<
typename
Problem
::
XDataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
BlockShape
>
;
return
BlockReduce2dCrossWarpSync
<
P_
>
{};
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
if
constexpr
(
Problem
::
kNeedCrossWarpSync
)
{
using
P_
=
BlockReduce2dProblem
<
typename
Problem
::
XDataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
BlockShape
>
;
using
block_reduce2d
=
BlockReduce2d
<
P_
>
;
using
x_block_tile
=
decltype
(
make_static_distributed_tensor
<
typename
Problem
::
XDataType
>
(
MakeXBlockTileDistribution
<
Problem
>
()));
using
y_block_tile
=
decltype
(
block_reduce2d
::
template
MakeYBlockTile
<
x_block_tile
>());
return
GetBlockReduce2dCrossWarpSync
<
Problem
>
().
template
GetSmemSize
<
y_block_tile
>();
}
else
{
return
1
;
// zero size arrays are an extension
}
}
};
}
// namespace ck_tile
include/ck_tile/ops/reduce/block/block_reduce2d_problem.hpp
0 → 100644
View file @
7d50244e
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
template
<
typename
XDataType_
,
typename
ComputeDataType_
,
typename
BlockShape_
>
struct
BlockReduce2dProblem
{
using
XDataType
=
remove_cvref_t
<
XDataType_
>
;
using
ComputeDataType
=
remove_cvref_t
<
ComputeDataType_
>
;
using
BlockShape
=
remove_cvref_t
<
BlockShape_
>
;
};
}
// namespace ck_tile
include/ck_tile/ops/rmsnorm2d.hpp
0 → 100644
View file @
7d50244e
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp"
#include "ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_shape.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_problem.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp
0 → 100644
View file @
7d50244e
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
namespace
ck_tile
{
// host side args
struct
Rmsnorm2dFwdHostArgs
{
const
void
*
p_x
;
const
void
*
p_gamma
;
void
*
p_y
;
void
*
p_invRms
;
float
epsilon
;
index_t
m
;
index_t
n
;
index_t
stride
;
// row_stride
};
// TODO: Extract some type to wrapper class
template
<
typename
Pipeline_
>
struct
Rmsnorm2dFwd
{
using
Pipeline
=
remove_cvref_t
<
Pipeline_
>
;
using
Problem
=
typename
Pipeline
::
Problem
;
using
XDataType
=
remove_cvref_t
<
typename
Problem
::
XDataType
>
;
using
GammaDataType
=
remove_cvref_t
<
typename
Problem
::
GammaDataType
>
;
using
ComputeDataType
=
remove_cvref_t
<
typename
Problem
::
ComputeDataType
>
;
using
YDataType
=
remove_cvref_t
<
typename
Problem
::
YDataType
>
;
using
InvRmsDataType
=
remove_cvref_t
<
typename
Problem
::
InvRmsDataType
>
;
static
constexpr
bool
kHasGamma
=
!
std
::
is_same_v
<
GammaDataType
,
null_type
>
;
static
constexpr
bool
kSaveInvRms
=
Problem
::
kSaveInvRms
;
static
constexpr
index_t
Block_M
=
Problem
::
BlockShape
::
Block_M
;
static
constexpr
index_t
Block_N
=
Problem
::
BlockShape
::
Block_N
;
static
constexpr
bool
kPadM
=
false
;
// always no need to pad along M
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
bool
kTwoPass
=
Problem
::
kTwoPass
;
static
constexpr
index_t
ThreadPerWarp_N
=
Problem
::
BlockShape
::
ThreadPerWarp_N
;
static
constexpr
index_t
Vector_N
=
Problem
::
BlockShape
::
Vector_N
;
static
constexpr
index_t
Repeat_N
=
Problem
::
BlockShape
::
Repeat_N
;
static
constexpr
auto
I0
=
number
<
0
>
{};
static
constexpr
auto
I1
=
number
<
1
>
{};
struct
Kargs
{
const
void
*
p_x
;
const
void
*
p_gamma
;
void
*
p_y
;
void
*
p_invRms
;
float
epsilon
;
index_t
m
;
index_t
n
;
index_t
stride
;
// row_stride
};
using
Hargs
=
Rmsnorm2dFwdHostArgs
;
CK_TILE_HOST
static
constexpr
Kargs
MakeKargs
(
const
Hargs
&
hargs
)
{
return
Kargs
{
hargs
.
p_x
,
hargs
.
p_gamma
,
hargs
.
p_y
,
hargs
.
p_invRms
,
hargs
.
epsilon
,
hargs
.
m
,
hargs
.
n
,
hargs
.
stride
};
}
CK_TILE_HOST
static
constexpr
auto
GridSize
(
const
Hargs
&
hargs
)
{
return
(
hargs
.
m
+
Block_M
-
1
)
/
Block_M
;
}
CK_TILE_HOST
static
constexpr
auto
BlockSize
()
{
return
Problem
::
BlockShape
::
BlockSize
;
}
// clang-format off
template
<
typename
T
>
struct
t2s
;
template
<
>
struct
t2s
<
float
>
{
static
constexpr
const
char
*
name
=
"fp32"
;
};
template
<
>
struct
t2s
<
ck_tile
::
fp16_t
>
{
static
constexpr
const
char
*
name
=
"fp16"
;
};
template
<
>
struct
t2s
<
ck_tile
::
bf16_t
>
{
static
constexpr
const
char
*
name
=
"bf16"
;
};
template
<
>
struct
t2s
<
ck_tile
::
fp8_t
>
{
static
constexpr
const
char
*
name
=
"fp8"
;
};
template
<
>
struct
t2s
<
ck_tile
::
bf8_t
>
{
static
constexpr
const
char
*
name
=
"bf8"
;
};
// clang-format on
// in byte
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
return
Pipeline
::
GetSmemSize
();
}
CK_TILE_HOST
static
std
::
string
GetName
()
{
// clang-format off
using
S_
=
typename
Problem
::
BlockShape
;
auto
surfix
=
[
&
]
()
{
std
::
string
n
;
if
(
kPadN
)
n
+=
"_pn"
;
if
(
kSaveInvRms
)
n
+=
"_rms"
;
if
(
kTwoPass
)
n
+=
"_2p"
;
return
n
;
}();
#define _SS_ std::string
#define _TS_ std::to_string
return
_SS_
(
"rmsnorm2d_fwd_"
)
+
_SS_
(
t2s
<
XDataType
>::
name
)
+
"_"
+
_TS_
(
S_
::
Block_M
)
+
"x"
+
_TS_
(
S_
::
Block_N
)
+
"_"
+
_TS_
(
S_
::
WarpPerBlock_M
)
+
"x"
+
_TS_
(
S_
::
WarpPerBlock_N
)
+
"_"
+
_TS_
(
S_
::
Warp_M
)
+
"x"
+
_TS_
(
S_
::
Warp_N
)
+
"_"
+
_TS_
(
S_
::
Vector_M
)
+
"x"
+
_TS_
(
S_
::
Vector_N
)
+
"_"
+
_SS_
(
Pipeline
::
name
)
+
surfix
;
#undef _SS_
#undef _TS_
// clang-format on
}
CK_TILE_DEVICE
void
operator
()(
Kargs
kargs
)
const
{
const
auto
iM
=
get_block_id
()
*
Block_M
;
const
auto
x_window
=
[
&
]()
{
const
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
static_cast
<
const
XDataType
*>
(
kargs
.
p_x
),
make_tuple
(
kargs
.
m
,
kargs
.
n
),
make_tuple
(
kargs
.
stride
,
1
),
number
<
Vector_N
>
{},
number
<
1
>
{});
const
auto
tmp2_
=
pad_tensor_view
(
tmp_
,
make_tuple
(
number
<
Block_M
>
{},
number
<
Block_N
>
{}),
sequence
<
kPadM
,
kPadN
>
{});
return
make_tile_window
(
tmp2_
,
make_tuple
(
number
<
Block_M
>
{},
number
<
Block_N
>
{}),
{
iM
,
0
});
}();
const
auto
gamma_window
=
[
&
]()
{
const
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
static_cast
<
const
GammaDataType
*>
(
kargs
.
p_gamma
),
make_tuple
(
kargs
.
n
),
make_tuple
(
1
),
number
<
Vector_N
>
{},
number
<
1
>
{});
const
auto
tmp2_
=
pad_tensor_view
(
tmp_
,
make_tuple
(
number
<
Block_N
>
{}),
sequence
<
kPadM
>
{});
return
make_tile_window
(
tmp2_
,
make_tuple
(
number
<
Block_N
>
{}),
{
0
});
}();
auto
y_window
=
[
&
]()
{
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
static_cast
<
YDataType
*>
(
kargs
.
p_y
),
make_tuple
(
kargs
.
m
,
kargs
.
n
),
make_tuple
(
kargs
.
stride
,
1
),
number
<
Vector_N
>
{},
number
<
1
>
{});
auto
tmp2_
=
pad_tensor_view
(
tmp_
,
make_tuple
(
number
<
Block_M
>
{},
number
<
Block_N
>
{}),
sequence
<
kPadM
,
kPadN
>
{});
return
make_tile_window
(
tmp2_
,
make_tuple
(
number
<
Block_M
>
{},
number
<
Block_N
>
{}),
{
iM
,
0
});
}();
auto
inv_rms_window
=
[
&
]()
{
if
constexpr
(
kSaveInvRms
)
{
const
auto
inv_rms_m
=
[
&
]()
{
const
auto
inv_rms_dram_naive
=
make_naive_tensor_view_packed
<
address_space_enum
::
global
>
(
static_cast
<
InvRmsDataType
*>
(
kargs
.
p_invRms
),
make_tuple
(
kargs
.
m
),
number
<
1
>
{});
return
pad_tensor_view
(
inv_rms_dram_naive
,
make_tuple
(
number
<
Block_M
>
{}),
sequence
<
kPadM
>
{});
}();
return
make_tile_window
(
inv_rms_m
,
make_tuple
(
number
<
Block_M
>
{}),
{
iM
});
}
else
return
make_null_tile_window
(
make_tuple
(
number
<
Block_M
>
{}));
}();
__shared__
char
smem
[
GetSmemSize
()];
Pipeline
{}(
x_window
,
gamma_window
,
y_window
,
inv_rms_window
,
static_cast
<
const
ComputeDataType
>
(
kargs
.
epsilon
),
kargs
.
n
,
smem
);
}
};
}
// namespace ck_tile
include/ck_tile/ops/
layer
norm2d/kernel/
layer
norm2d_fwd_shape.hpp
→
include/ck_tile/ops/
rms
norm2d/kernel/
rms
norm2d_fwd_shape.hpp
View file @
7d50244e
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -42,7 +42,7 @@ template <typename BlockTile_, // block size, seq<M, N>
typename
Vector_
,
// contiguous pixels(vector size) along seq<M, N>
index_t
BlockSize_
=
warpSize
*
reduce_on_sequence
(
WarpPerBlock_
{}
,
multiplies
{}
,
number
<
1
>{})
>
struct
Layer
norm2dShape
struct
Rms
norm2dShape
{
// block size
static
constexpr
index_t
Block_M
=
BlockTile_
::
at
(
number
<
0
>
{});
...
...
include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp
0 → 100644
View file @
7d50244e
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/reduce/block/block_reduce2d_problem.hpp"
#include "ck_tile/ops/reduce/block/block_reduce2d.hpp"
namespace
ck_tile
{
struct
Rmsnorm2dFwdPipelineDefaultPolicy
{
template
<
typename
Problem
>
CK_TILE_DEVICE
static
constexpr
auto
MakeXBlockTileDistribution
()
{
using
S
=
typename
Problem
::
BlockShape
;
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
S
::
Repeat_M
,
S
::
WarpPerBlock_M
,
S
::
ThreadPerWarp_M
,
S
::
Vector_M
>
,
sequence
<
S
::
Repeat_N
,
S
::
WarpPerBlock_N
,
S
::
ThreadPerWarp_N
,
S
::
Vector_N
>>
,
tuple
<
sequence
<
1
,
2
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
,
1
>
,
sequence
<
2
,
2
>>
,
sequence
<
1
,
1
,
2
,
2
>
,
sequence
<
0
,
3
,
0
,
3
>>
{});
}
template
<
typename
Problem
>
CK_TILE_DEVICE
static
constexpr
auto
MakeGammaBlockTileDistribution
()
{
using
S
=
typename
Problem
::
BlockShape
;
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
S
::
WarpPerBlock_M
,
S
::
ThreadPerWarp_M
>
,
tuple
<
sequence
<
S
::
Repeat_N
,
S
::
WarpPerBlock_N
,
S
::
ThreadPerWarp_N
,
S
::
Vector_N
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
0
,
1
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
1
,
2
>>
,
sequence
<
1
,
1
>
,
sequence
<
0
,
3
>>
{});
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockReduce2d
()
{
using
P_
=
BlockReduce2dProblem
<
typename
Problem
::
XDataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
BlockShape
>
;
return
BlockReduce2d
<
P_
>
{};
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockReduce2dSync
()
{
using
P_
=
BlockReduce2dProblem
<
typename
Problem
::
XDataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
BlockShape
>
;
return
BlockReduce2dSync
<
P_
>
{};
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockReduce2dCrossWarpSync
()
{
using
P_
=
BlockReduce2dProblem
<
typename
Problem
::
XDataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
BlockShape
>
;
return
BlockReduce2dCrossWarpSync
<
P_
>
{};
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
if
constexpr
(
Problem
::
kNeedCrossWarpSync
)
{
using
P_
=
BlockReduce2dProblem
<
typename
Problem
::
XDataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
BlockShape
>
;
using
block_reduce2d
=
BlockReduce2d
<
P_
>
;
using
x_block_tile
=
decltype
(
make_static_distributed_tensor
<
typename
Problem
::
XDataType
>
(
MakeXBlockTileDistribution
<
Problem
>
()));
using
y_block_tile
=
decltype
(
block_reduce2d
::
template
MakeYBlockTile
<
x_block_tile
>());
return
GetBlockReduce2dCrossWarpSync
<
Problem
>
().
template
GetSmemSize
<
y_block_tile
>();
}
else
{
return
1
;
// zero size arrays are an extension
}
}
};
}
// namespace ck_tile
include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp
0 → 100644
View file @
7d50244e
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp"
#include <string>
#include <type_traits>
namespace
ck_tile
{
template
<
typename
Problem_
,
typename
Policy_
=
Rmsnorm2dFwdPipelineDefaultPolicy
>
struct
Rmsnorm2dFwdPipelineOnePass
{
using
Problem
=
ck_tile
::
remove_cvref_t
<
Problem_
>
;
using
Policy
=
ck_tile
::
remove_cvref_t
<
Policy_
>
;
using
XDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
XDataType
>
;
using
GammaDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
GammaDataType
>
;
using
ComputeDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
ComputeDataType
>
;
using
YDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
YDataType
>
;
using
InvRmsDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
InvRmsDataType
>
;
static
constexpr
bool
kHasGamma
=
!
std
::
is_same_v
<
GammaDataType
,
ck_tile
::
null_type
>
;
static
constexpr
bool
kSaveInvRms
=
Problem
::
kSaveInvRms
;
static
constexpr
bool
kNeedCrossWarpSync
=
Problem
::
kNeedCrossWarpSync
;
static
constexpr
bool
kPadM
=
false
;
// TODO - BlockRmsnorm2dFwdProblem::kPadM
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
const
char
*
name
=
[]()
{
if
constexpr
(
kNeedCrossWarpSync
)
return
"bpr_op"
;
// block per row
else
return
"wpr_op"
;
// warp per row
}();
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
return
Policy
::
template
GetSmemSize
<
Problem
>();
}
template
<
typename
XWindow
,
typename
GammaWindow
,
typename
YWindow
,
typename
InvRmsWindow
>
CK_TILE_DEVICE
auto
operator
()(
const
XWindow
&
x_window_
,
const
GammaWindow
&
gamma_window_
,
YWindow
&
y_window
,
InvRmsWindow
&
inv_rms_window
,
ComputeDataType
epsilon
,
ck_tile
::
index_t
row_size
,
void
*
smem
)
const
{
const
auto
x_window
=
make_tile_window
(
x_window_
,
Policy
::
template
MakeXBlockTileDistribution
<
Problem
>());
const
auto
gamma_window
=
make_tile_window
(
gamma_window_
,
Policy
::
template
MakeGammaBlockTileDistribution
<
Problem
>());
auto
reduce_square_sum_func
=
ReduceOp
::
SquareAdd
{};
auto
reduce_sum_func
=
ReduceOp
::
Add
{};
auto
block_reduce2d
=
Policy
::
template
GetBlockReduce2d
<
Problem
>();
auto
block_reduce2d_sync
=
Policy
::
template
GetBlockReduce2dSync
<
Problem
>();
auto
block_reduce2d_cross_warp_sync
=
Policy
::
template
GetBlockReduce2dCrossWarpSync
<
Problem
>();
const
auto
x
=
load_tile
(
x_window
);
// load gamma (TODO: support no gamma?)
const
auto
gamma
=
load_tile
(
gamma_window
);
// compute mean square each-thread->cross-lane->cross-warp
auto
square_sum
=
block_reduce2d
(
x
,
reduce_square_sum_func
.
GetIdentityValue
<
ComputeDataType
>
(),
reduce_square_sum_func
);
block_reduce2d_sync
(
square_sum
,
reduce_sum_func
);
block_reduce2d_cross_warp_sync
(
square_sum
,
smem
,
reduce_sum_func
);
// compute inv-rms
auto
inv_rms
=
tile_elementwise_in
(
[
&
](
const
auto
&
v_
)
{
return
type_convert
<
ComputeDataType
>
(
1.0
f
)
/
(
sqrt
(
v_
/
row_size
+
epsilon
));
},
square_sum
);
if
constexpr
(
kSaveInvRms
)
store_tile
(
inv_rms_window
,
cast_tile
<
InvRmsDataType
>
(
inv_rms
));
// rmsnorm computation
auto
y
=
make_static_distributed_tensor
<
YDataType
>
(
x
.
get_tile_distribution
());
sweep_tile
(
y
,
[
&
,
inv_rms_
=
inv_rms
](
auto
idx
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx
[
number
<
0
>
{}]);
constexpr
auto
j_idx
=
make_tuple
(
idx
[
number
<
1
>
{}]);
const
auto
gamma_
=
type_convert
<
ComputeDataType
>
(
gamma
[
j_idx
]);
const
auto
x_
=
type_convert
<
ComputeDataType
>
(
x
[
idx
]);
auto
y_
=
x_
*
inv_rms_
[
i_idx
]
*
gamma_
;
y
(
idx
)
=
type_convert
<
YDataType
>
(
y_
);
});
store_tile
(
y_window
,
y
);
}
};
}
// namespace ck_tile
include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_problem.hpp
0 → 100644
View file @
7d50244e
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/utility/type_traits.hpp"
namespace
ck_tile
{
template
<
typename
XDataType_
,
typename
GammaDataType_
,
typename
ComputeDataType_
,
typename
YDataType_
,
typename
InvRmsDataType_
,
typename
BlockShape_
,
bool
kPadN_
,
bool
kSaveInvRms_
,
bool
kTwoPass_
>
struct
Rmsnorm2dFwdPipelineProblem
{
using
XDataType
=
remove_cvref_t
<
XDataType_
>
;
using
GammaDataType
=
remove_cvref_t
<
GammaDataType_
>
;
using
ComputeDataType
=
remove_cvref_t
<
ComputeDataType_
>
;
using
YDataType
=
remove_cvref_t
<
YDataType_
>
;
using
InvRmsDataType
=
remove_cvref_t
<
InvRmsDataType_
>
;
using
BlockShape
=
remove_cvref_t
<
BlockShape_
>
;
static
constexpr
bool
kNeedCrossLaneSync
=
BlockShape
::
ThreadPerWarp_N
>
1
;
static
constexpr
bool
kNeedCrossWarpSync
=
BlockShape
::
WarpPerBlock_N
>
1
;
static
constexpr
bool
kPadN
=
kPadN_
;
static
constexpr
bool
kSaveInvRms
=
kSaveInvRms_
;
static
constexpr
bool
kTwoPass
=
kTwoPass_
;
};
}
// namespace ck_tile
include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp
0 → 100644
View file @
7d50244e
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp"
#include <string>
#include <type_traits>
namespace
ck_tile
{
template
<
typename
Problem_
,
typename
Policy_
=
Rmsnorm2dFwdPipelineDefaultPolicy
>
struct
Rmsnorm2dFwdPipelineTwoPass
{
using
Problem
=
ck_tile
::
remove_cvref_t
<
Problem_
>
;
using
Policy
=
ck_tile
::
remove_cvref_t
<
Policy_
>
;
using
XDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
XDataType
>
;
using
GammaDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
GammaDataType
>
;
using
ComputeDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
ComputeDataType
>
;
using
YDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
YDataType
>
;
using
InvRmsDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
InvRmsDataType
>
;
static
constexpr
bool
kHasGamma
=
!
std
::
is_same_v
<
GammaDataType
,
ck_tile
::
null_type
>
;
static
constexpr
bool
kSaveInvRms
=
Problem
::
kSaveInvRms
;
static
constexpr
bool
kNeedCrossWarpSync
=
Problem
::
kNeedCrossWarpSync
;
static
constexpr
bool
kPadM
=
false
;
// TODO - BlockRmsnorm2dFwdProblem::kPadM
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
const
char
*
name
=
[]()
{
if
constexpr
(
kNeedCrossWarpSync
)
return
"bpr_tp"
;
// block per row
else
return
"wpr_tp"
;
// warp per row
}();
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
return
Policy
::
template
GetSmemSize
<
Problem
>();
}
template
<
typename
XWindow
,
typename
GammaWindow
,
typename
YWindow
,
typename
InvRmsWindow
>
CK_TILE_DEVICE
auto
operator
()(
const
XWindow
&
x_window_
,
const
GammaWindow
&
gamma_window_
,
YWindow
&
y_window
,
InvRmsWindow
&
inv_rms_window
,
ComputeDataType
epsilon
,
ck_tile
::
index_t
row_size
,
void
*
smem
)
const
{
auto
x_window
=
make_tile_window
(
x_window_
,
Policy
::
template
MakeXBlockTileDistribution
<
Problem
>());
auto
gamma_window
=
make_tile_window
(
gamma_window_
,
Policy
::
template
MakeGammaBlockTileDistribution
<
Problem
>());
// Problem::BlockShape
static
constexpr
index_t
Block_N
=
Problem
::
BlockShape
::
Block_N
;
index_t
num_n_tile_iteration
=
__builtin_amdgcn_readfirstlane
(
integer_divide_ceil
(
row_size
,
Block_N
));
auto
reduce_square_sum_func
=
ReduceOp
::
SquareAdd
{};
auto
reduce_sum_func
=
ReduceOp
::
Add
{};
auto
block_reduce2d
=
Policy
::
template
GetBlockReduce2d
<
Problem
>();
auto
block_reduce2d_sync
=
Policy
::
template
GetBlockReduce2dSync
<
Problem
>();
auto
block_reduce2d_cross_warp_sync
=
Policy
::
template
GetBlockReduce2dCrossWarpSync
<
Problem
>();
using
XTensorType
=
decltype
(
load_tile
(
x_window
));
auto
square_sum
=
block_reduce2d
.
template
MakeYBlockTile
<
XTensorType
>();
set_tile
(
square_sum
,
reduce_square_sum_func
.
GetIdentityValue
<
ComputeDataType
>
());
for
(
int
iN
=
__builtin_amdgcn_readfirstlane
(
0
);
iN
<
num_n_tile_iteration
;
++
iN
)
{
const
auto
x
=
load_tile
(
x_window
);
block_reduce2d
(
x
,
square_sum
,
reduce_square_sum_func
);
move_tile_window
(
x_window
,
{
0
,
Block_N
});
}
block_reduce2d_sync
(
square_sum
,
reduce_sum_func
);
block_reduce2d_cross_warp_sync
(
square_sum
,
smem
,
reduce_sum_func
);
// compute inv-rms
auto
inv_rms
=
tile_elementwise_in
(
[
&
](
const
auto
&
v_
)
{
return
type_convert
<
ComputeDataType
>
(
1.0
f
)
/
(
sqrt
(
v_
/
row_size
+
epsilon
));
},
square_sum
);
if
constexpr
(
kSaveInvRms
)
store_tile
(
inv_rms_window
,
cast_tile
<
InvRmsDataType
>
(
inv_rms
));
// reverse read x to reuse cache
ck_tile
::
index_t
stride_to_right_most_window
=
row_size
%
Block_N
==
0
?
row_size
-
Block_N
:
row_size
-
row_size
%
Block_N
;
move_tile_window
(
x_window
,
{
0
,
-
Block_N
});
move_tile_window
(
gamma_window
,
{
stride_to_right_most_window
});
move_tile_window
(
y_window
,
{
0
,
stride_to_right_most_window
});
// rmsnorm computation
for
(
int
iN
=
__builtin_amdgcn_readfirstlane
(
0
);
iN
<
num_n_tile_iteration
;
++
iN
)
{
const
auto
x
=
load_tile
(
x_window
);
// load gamma/beta (TODO: support no gamma/beta?)
const
auto
gamma
=
load_tile
(
gamma_window
);
auto
y
=
make_static_distributed_tensor
<
YDataType
>
(
x
.
get_tile_distribution
());
sweep_tile
(
y
,
[
&
,
inv_rms_
=
inv_rms
](
auto
idx
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx
[
number
<
0
>
{}]);
constexpr
auto
j_idx
=
make_tuple
(
idx
[
number
<
1
>
{}]);
const
auto
gamma_
=
type_convert
<
ComputeDataType
>
(
gamma
[
j_idx
]);
const
auto
x_
=
type_convert
<
ComputeDataType
>
(
x
[
idx
]);
auto
y_
=
x_
*
inv_rms_
[
i_idx
]
*
gamma_
;
y
(
idx
)
=
type_convert
<
YDataType
>
(
y_
);
});
store_tile
(
y_window
,
y
);
move_tile_window
(
x_window
,
{
0
,
-
Block_N
});
move_tile_window
(
gamma_window
,
{
-
Block_N
});
move_tile_window
(
y_window
,
{
0
,
-
Block_N
});
}
}
};
}
// namespace ck_tile
include/ck_tile/ops/softmax.hpp
0 → 100644
View file @
7d50244e
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/softmax/block/block_softmax_2d.hpp"
#include "ck_tile/ops/softmax/block/block_softmax_2d_problem.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
Prev
1
…
8
9
10
11
12
13
14
15
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment