gaoqiong / composable_kernel_ROCM · Commits · 5c9cba82

Commit 5c9cba82 authored Oct 26, 2024 by felix

    add l

parent 54f0e6f4
Showing 6 changed files with 398 additions and 2 deletions (+398 −2)
include/ck_tile/ops/layernorm2d/kernel/layernorm2d_bwd_gamma_beta_kernel.hpp          +200 −0
include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp                     +2 −2
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_bwd_gamma_beta.hpp               +82 −0
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_bwd_pipeline_default_policy.hpp  +50 −0
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_bwd_pipeline_problem.hpp         +35 −0
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_problem.hpp         +29 −0
include/ck_tile/ops/layernorm2d/kernel/layernorm2d_bwd_gamma_beta_kernel.hpp (new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"

namespace ck_tile {

// host side args
struct Layernorm2dBwdGammaBetaHostArgs
{
    const void* p_dY;
    const void* p_mean;
    const void* p_invStd;
    void* p_dGamma;
    void* p_dBeta;
    void* p_yMul;

    index_t m;
    index_t n;
    index_t stride; // row_stride
};

// TODO: Extract some type to wrapper class
template <typename Pipeline_>
struct Layernorm2dBwdGammaBeta
{
    using Pipeline = remove_cvref_t<Pipeline_>;
    using Problem  = typename Pipeline::Problem;

    using XDataType       = remove_cvref_t<typename Problem::XDataType>;
    using GammaDataType   = remove_cvref_t<typename Problem::GammaDataType>;
    using BetaDataType    = remove_cvref_t<typename Problem::BetaDataType>;
    using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
    using YDataType       = remove_cvref_t<typename Problem::YDataType>;
    using MeanDataType    = remove_cvref_t<typename Problem::MeanDataType>;
    using InvStdDataType  = remove_cvref_t<typename Problem::InvStdDataType>;

    static constexpr index_t Block_M = Problem::BlockShape::Block_M;
    static constexpr index_t Block_N = Problem::BlockShape::Block_N;
    static constexpr bool kPadM      = false; // always no need to pad along M
    static constexpr bool kPadN      = Problem::kPadN;

    static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N;

    static constexpr auto I0 = number<0>{};
    static constexpr auto I1 = number<1>{};

    struct Kargs
    {
        const void* p_dY;
        const void* p_mean;
        const void* p_invStd;
        void* p_dGamma;
        void* p_dBeta;
        void* p_yMul;

        index_t m;
        index_t n;
        index_t stride; // row_stride
    };
    using Hargs = Layernorm2dBwdGammaBetaHostArgs;

    CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs)
    {
        return Kargs{hargs.p_dY,
                     hargs.p_mean,
                     hargs.p_invStd,
                     hargs.p_dGamma,
                     hargs.p_dBeta,
                     hargs.p_yMul,
                     hargs.m,
                     hargs.n,
                     hargs.stride};
    }

    CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs)
    {
        return (hargs.m + Block_M - 1) / Block_M;
    }

    CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::BlockSize; }

    // clang-format off
    template <typename T> struct t2s;
    template <> struct t2s<float> { static constexpr const char* name = "fp32"; };
    template <> struct t2s<ck_tile::fp16_t> { static constexpr const char* name = "fp16"; };
    template <> struct t2s<ck_tile::bf16_t> { static constexpr const char* name = "bf16"; };
    template <> struct t2s<ck_tile::fp8_t> { static constexpr const char* name = "fp8"; };
    template <> struct t2s<ck_tile::bf8_t> { static constexpr const char* name = "bf8"; };
    // clang-format on

    // in byte
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return Pipeline::GetSmemSize(); }

    CK_TILE_HOST static std::string GetName()
    {
        // clang-format off
        using S_ = typename Problem::BlockShape;
        auto surfix = [&]() {
            std::string n;
            if(kPadN) n += "_pn";
            if(kSaveMeanInvStd) n += "_mv";
            if(kTwoPass) n += "_2p";
            return n;
        }();

        #define _SS_ std::string
        #define _TS_ std::to_string
        return _SS_("layernorm2d_fwd_") + _SS_(t2s<XDataType>::name) + "_" +
            _TS_(S_::Block_M) + "x" + _TS_(S_::Block_N) + "_" +
            _TS_(S_::WarpPerBlock_M) + "x" + _TS_(S_::WarpPerBlock_N) + "_" +
            _TS_(S_::Warp_M) + "x" + _TS_(S_::Warp_N) + "_" +
            _TS_(S_::Vector_M) + "x" + _TS_(S_::Vector_N) + "_" +
            _SS_(Pipeline::name) + surfix;
        #undef _SS_
        #undef _TS_
        // clang-format on
    }

    CK_TILE_DEVICE void operator()(Kargs kargs) const
    {
        const auto iM = get_block_id() * Block_M;

        const auto dy_window = [&]() {
            const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
                static_cast<const YDataType*>(kargs.p_dY),
                make_tuple(kargs.m, kargs.n),
                make_tuple(kargs.stride, 1));

            // NOTE: we don't do any pad in this kernel for loading, assume that inside kernel will
            // check the max count dynamically
            const auto tmp2_ = pad_tensor_view(
                tmp_, make_tuple(number<Block_M>{}, number<Block_N>{}), sequence<false, false>{});
            return make_tile_window(
                tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
        }();

        const auto mean_window = [&]() {
            const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
                static_cast<const MeanDataType*>(kargs.p_mean),
                make_tuple(kargs.m),
                make_tuple(1));

            const auto tmp2_ =
                pad_tensor_view(tmp_, make_tuple(number<Block_M>{}), sequence<false>{});
            return make_tile_window(tmp2_, make_tuple(number<Block_M>{}), {0});
        }();

        const auto invstd_window = [&]() {
            const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
                static_cast<const MeanDataType*>(kargs.p_invStd),
                make_tuple(kargs.m),
                make_tuple(1));

            const auto tmp2_ =
                pad_tensor_view(tmp_, make_tuple(number<Block_M>{}), sequence<false>{});
            return make_tile_window(tmp2_, make_tuple(number<Block_M>{}), {0});
        }();

        const auto dgamma_window = [&]() {
            const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
                static_cast<const GammaDataType*>(kargs.p_dGamma),
                make_tuple(kargs.n),
                make_tuple(1));

            const auto tmp2_ =
                pad_tensor_view(tmp_, make_tuple(number<Block_N>{}), sequence<false>{});
            return make_tile_window(tmp2_, make_tuple(number<Block_N>{}), {0});
        }();

        const auto dbeta_window = [&]() {
            const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
                static_cast<const BetaDataType*>(kargs.p_dBeta),
                make_tuple(kargs.n),
                make_tuple(1));

            const auto tmp2_ =
                pad_tensor_view(tmp_, make_tuple(number<Block_N>{}), sequence<false>{});
            return make_tile_window(
                tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {0});
        }();

        __shared__ char smem[GetSmemSize()];

        Pipeline{}(dy_window, mean_window, invstd_window, dgamma_window, dbeta_window, kargs.n, smem);
    }
};
} // namespace ck_tile
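
For context, a host-side caller would fill Layernorm2dBwdGammaBetaHostArgs and use MakeKargs, GridSize, and BlockSize to launch the kernel. Below is a minimal, hedged sketch of that wiring; it is not part of this commit, the Kernel template parameter stands for a concrete Layernorm2dBwdGammaBeta<...> instantiation, and the launch helpers (ck_tile::launch_kernel, ck_tile::make_kernel, ck_tile::stream_config) are assumed to be available from the surrounding ck_tile host utilities.

    // Hedged sketch (not in this commit): host-side wiring for the new bwd gamma/beta kernel.
    template <typename Kernel>
    void run_bwd_gamma_beta(const void* p_dy,      // [m, n] upstream gradient dL/dY
                            const void* p_mean,    // [m] per-row mean from the forward pass
                            const void* p_invstd,  // [m] per-row 1/std from the forward pass
                            void* p_dgamma,        // [n] output dL/dGamma
                            void* p_dbeta,         // [n] output dL/dBeta
                            ck_tile::index_t m,
                            ck_tile::index_t n,
                            const ck_tile::stream_config& s)
    {
        // p_yMul is left unused here; stride assumes a densely packed row-major dY.
        ck_tile::Layernorm2dBwdGammaBetaHostArgs hargs{
            p_dy, p_mean, p_invstd, p_dgamma, p_dbeta, /*p_yMul=*/nullptr, m, n, /*stride=*/n};

        auto kargs            = Kernel::MakeKargs(hargs);
        const auto grids      = Kernel::GridSize(hargs);   // one block per Block_M rows
        constexpr auto blocks = Kernel::BlockSize();

        ck_tile::launch_kernel(s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs));
    }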
include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp

@@ -94,7 +94,7 @@ struct Layernorm2dFwd

     CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs)
     {
-        return (hargs.m + Block_M - 1) / Block_M;
+        return dim3((hargs.n + Block_N - 1) / Block_N, (hargs.m + Block_M - 1) / Block_M);
     }

     CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::BlockSize; }

@@ -124,7 +124,7 @@ struct Layernorm2dFwd
         #define _SS_ std::string
         #define _TS_ std::to_string
-        return _SS_("layernorm2d_fwd_") + _SS_(t2s<XDataType>::name) + "_" +
+        return _SS_("layernorm2d_bwd_gamma_beta_") + _SS_(t2s<XDataType>::name) + "_" +
             _TS_(S_::Block_M) + "x" + _TS_(S_::Block_N) + "_" +
             _TS_(S_::WarpPerBlock_M) + "x" + _TS_(S_::WarpPerBlock_N) + "_" +
             _TS_(S_::Warp_M) + "x" + _TS_(S_::Warp_N) + "_" +
             _TS_(S_::Vector_M) + "x" + _TS_(S_::Vector_N) + "_" +
             _SS_(Pipeline::name) + surfix;
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_bwd_gamma_beta.hpp (new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp"
#include <string>
#include <type_traits>

namespace ck_tile {

template <typename Problem_, typename Policy_ = Layernorm2dBwdGammaBetaPipelineDefaultPolicy>
struct Layernorm2dBwdGammaBetaPipeline
{
    using Problem = ck_tile::remove_cvref_t<Problem_>;
    using Policy  = ck_tile::remove_cvref_t<Policy_>;

    using XDataType       = ck_tile::remove_cvref_t<typename Problem::XDataType>;
    using GammaDataType   = ck_tile::remove_cvref_t<typename Problem::GammaDataType>;
    using BetaDataType    = ck_tile::remove_cvref_t<typename Problem::BetaDataType>;
    using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
    using YDataType       = ck_tile::remove_cvref_t<typename Problem::YDataType>;
    using MeanDataType    = ck_tile::remove_cvref_t<typename Problem::MeanDataType>;
    using InvStdDataType  = ck_tile::remove_cvref_t<typename Problem::InvStdDataType>;

    static constexpr bool kPadM = false; // TODO - BlockLayernorm2dBwdGammaBetaProblem::kPadM
    static constexpr bool kPadN = Problem::kPadN;

    static constexpr const char* name = []() { return "bwd_gamma_beta"; }();

    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
    {
        return Policy::template GetSmemSize<Problem>();
    }

    template <typename DYWindow,
              typename MeanWindow,
              typename InvStdWindow,
              typename DGammaWindow,
              typename DBetaWindow>
    CK_TILE_DEVICE auto operator()(const DYWindow& dy_window_,
                                   const MeanWindow& mean_window_,
                                   const InvStdWindow& inv_std_window_,
                                   DGammaWindow& gamma_window_,
                                   DBetaWindow& beta_window_,
                                   ck_tile::index_t row_size,
                                   void* smem) const
    {
        const auto dy_window = make_tile_window(
            dy_window_, Policy::template MakeDyBlockTileDistribution<Problem>());
        const auto mean_window = make_tile_window(
            mean_window_, Policy::template MakeMeanBlockTileDistribution<Problem>());
        // const auto gamma_window = make_tile_window(
        //     gamma_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
        // const auto beta_window = make_tile_window(
        //     beta_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());

        const auto dy   = load_tile(dy_window);
        const auto mean = load_tile(mean_window);

        // layernorm computation
        // auto y = make_static_distributed_tensor<YDataType>(x.get_tile_distribution());
        // sweep_tile(y, [&, mean_ = mean](auto idx) {
        //     constexpr auto i_idx = make_tuple(idx[number<0>{}]);
        //     constexpr auto j_idx = make_tuple(idx[number<1>{}]);
        //     const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]);
        //     const auto beta_  = type_convert<ComputeDataType>(beta[j_idx]);
        //     const auto x_ = type_convert<ComputeDataType>(x[idx]);
        //     auto y_ = (x_ - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_;
        //     y(idx) = type_convert<YDataType>(y_);
        // });
        // store_tile(y_window, y);
    }
};
} // namespace ck_tile
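
For reference (the body of operator() above is still a stub with the computation commented out), the reduction this pipeline is presumably meant to implement is the standard layernorm gamma/beta backward. With per-row mean \(\mu_i\) and inverse standard deviation \(r_i\) saved by the forward pass, the per-feature gradients are:

\[
\frac{\partial L}{\partial \gamma_j} = \sum_{i} \frac{\partial L}{\partial y_{ij}}\,(x_{ij}-\mu_i)\,r_i,
\qquad
\frac{\partial L}{\partial \beta_j} = \sum_{i} \frac{\partial L}{\partial y_{ij}},
\]

i.e. a column-wise reduction over the rows handled by each block; this formula is standard layernorm math and is not spelled out in the commit itself.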
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_bwd_pipeline_default_policy.hpp (new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/welford/block/block_welford_problem.hpp"
#include "ck_tile/ops/welford/block/block_welford.hpp"

namespace ck_tile {

struct Layernorm2dBwdGammaBetaPipelineDefaultPolicy
{
    template <typename Problem>
    CK_TILE_DEVICE static constexpr auto MakeDyBlockTileDistribution()
    {
        using S = typename Problem::BlockShape;
        return make_static_tile_distribution(
            tile_distribution_encoding<
                sequence<>,
                tuple<sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M>,
                      sequence<S::WarpPerBlock_N, S::ThreadPerWarp_N>>,
                tuple<sequence<1, 2>, sequence<1, 2>>,
                tuple<sequence<1, 0>, sequence<2, 1>>,
                sequence<1>,
                sequence<0>>{});
    }

    template <typename Problem>
    CK_TILE_DEVICE static constexpr auto MakeMeanBlockTileDistribution()
    {
        using S = typename Problem::BlockShape;
        return make_static_tile_distribution(
            tile_distribution_encoding<
                sequence<S::WarpPerBlock_N, S::ThreadPerWarp_N>,
                tuple<sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M>>,
                tuple<sequence<1, 0>, sequence<1, 0>>,
                tuple<sequence<1, 0>, sequence<2, 1>>,
                sequence<1>,
                sequence<0>>{});
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
    {
        return 1;
    }
};
} // namespace ck_tile
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_bwd_pipeline_problem.hpp (new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/utility/type_traits.hpp"

namespace ck_tile {

template <typename XDataType_,
          typename GammaDataType_,
          typename BetaDataType_,
          typename ComputeDataType_,
          typename YDataType_,
          typename MeanDataType_,
          typename InvStdDataType_,
          typename BlockShape_,
          bool kPadN_,
          bool kSaveMeanInvStd_,
          bool kTwoPass_>
struct Layernorm2dBwdGammaBetaPipelineProblem
{
    using XDataType       = remove_cvref_t<XDataType_>;
    using GammaDataType   = remove_cvref_t<GammaDataType_>;
    using BetaDataType    = remove_cvref_t<BetaDataType_>;
    using ComputeDataType = remove_cvref_t<ComputeDataType_>;
    using YDataType       = remove_cvref_t<YDataType_>;
    using MeanDataType    = remove_cvref_t<MeanDataType_>;
    using InvStdDataType  = remove_cvref_t<InvStdDataType_>;
    using BlockShape      = remove_cvref_t<BlockShape_>;

    static constexpr bool kPadN = kPadN_;
};
} // namespace ck_tile
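
Putting the new pieces together, the problem, pipeline, and kernel templates introduced in this commit compose roughly as sketched below. Only the three Layernorm2dBwdGammaBeta* templates come from this commit; the block-shape type Shape_ and the concrete data types are assumptions for illustration.

    // Hedged sketch (not in this commit): composing the new problem/pipeline/kernel types.
    // "Shape_" stands for a block-shape type providing Block_M/Block_N, BlockSize,
    // WarpPerBlock_*, ThreadPerWarp_*, Repeat_M, Warp_*, and Vector_* as referenced above.
    using Problem_ = ck_tile::Layernorm2dBwdGammaBetaPipelineProblem<
        ck_tile::fp16_t,  // XDataType
        ck_tile::fp16_t,  // GammaDataType
        ck_tile::fp16_t,  // BetaDataType
        float,            // ComputeDataType
        ck_tile::fp16_t,  // YDataType
        float,            // MeanDataType
        float,            // InvStdDataType
        Shape_,           // BlockShape
        /*kPadN_*/ true,
        /*kSaveMeanInvStd_*/ false,
        /*kTwoPass_*/ false>;

    using Pipeline_ = ck_tile::Layernorm2dBwdGammaBetaPipeline<Problem_>;
    using Kernel_   = ck_tile::Layernorm2dBwdGammaBeta<Pipeline_>;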
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_problem.hpp

@@ -37,4 +37,33 @@ struct Layernorm2dFwdPipelineProblem
     static constexpr bool kTwoPass = kTwoPass_;
 };

+template <typename XDataType_,
+          typename GammaDataType_,
+          typename BetaDataType_,
+          typename ComputeDataType_,
+          typename YDataType_,
+          typename MeanDataType_,
+          typename InvStdDataType_,
+          typename BlockShape_,
+          bool kPadN_,
+          bool kSaveMeanInvStd_,
+          bool kTwoPass_>
+struct Layernorm2dFwdPipelineProblem
+{
+    using XDataType       = remove_cvref_t<XDataType_>;
+    using GammaDataType   = remove_cvref_t<GammaDataType_>;
+    using BetaDataType    = remove_cvref_t<BetaDataType_>;
+    using ComputeDataType = remove_cvref_t<ComputeDataType_>;
+    using YDataType       = remove_cvref_t<YDataType_>;
+    using MeanDataType    = remove_cvref_t<MeanDataType_>;
+    using InvStdDataType  = remove_cvref_t<InvStdDataType_>;
+    using BlockShape      = remove_cvref_t<BlockShape_>;
+
+    static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1;
+    static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1;
+
+    static constexpr bool kPadN           = kPadN_;
+    static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_;
+    static constexpr bool kTwoPass        = kTwoPass_;
+};
 } // namespace ck_tile