gaoqiong / composable_kernel_ROCM · Commits

Commit 96568141, authored Oct 14, 2024 by rocking

1. Add save mean and save std back
2. Move construction of tensor_view and tile_window to operator()

Parent: e0b473b6

Showing 4 changed files with 194 additions and 149 deletions.
example/ck_tile/02_layernorm2d/example_layernorm2d_fwd.cpp        (+18, -0)
example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp                (+2, -2)
example/ck_tile/02_layernorm2d/layernorm_dispatch.hpp             (+2, -0)
include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp (+172, -147)
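The bulk of the change is the second item in the commit message: tensor_view and tile_window construction moves out of OnePassLayernorm2dFwd/TwoPassLayernorm2dFwd and into operator(), which then hands ready-made windows to whichever pass runs. A minimal stand-alone sketch of that pattern (Window and make_window below are illustrative stand-ins, not the actual ck_tile types):

// Illustrative stand-ins for ck_tile's tensor views / tile windows; not the real types.
struct Window
{
    const float* base;
    int rows;
    int cols;
};

Window make_window(const float* p, int rows, int cols)
{
    // Real code: make_naive_tensor_view + pad_tensor_view + make_tile_window.
    return Window{p, rows, cols};
}

struct KernelSketch
{
    // After this commit, the pass functions are templated on window types and
    // receive ready-made windows by reference instead of raw pointers plus M/N.
    template <typename XWindow>
    void one_pass(XWindow& /*x_win*/, float /*eps*/, int /*N*/) const
    {
    }

    void operator()(const float* p_x, float eps, int M, int N) const
    {
        // operator() now owns all view/window construction...
        Window x_win = make_window(p_x, M, N);
        // ...and forwards the windows to the selected pass.
        one_pass(x_win, eps, N);
    }
};

int main()
{
    float x[4] = {0.f, 1.f, 2.f, 3.f};
    KernelSketch{}(x, 1e-5f, 2, 2);
}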
example/ck_tile/02_layernorm2d/example_layernorm2d_fwd.cpp

@@ -55,6 +55,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
     ck_tile::HostTensor<MeanDataType> mean_host_ref({M});
     ck_tile::HostTensor<InvStdDataType> invStd_host_ref({M});

+    // TODO - move SAVE_MEAN_INV_STD to user args
+#ifdef SAVE_MEAN_INV_STD
+    ck_tile::HostTensor<MeanDataType> mean_host_dev({M});
+    ck_tile::HostTensor<InvStdDataType> invStd_host_dev({M});
+#endif
+
     ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
     ck_tile::FillUniformDistribution<GammaDataType>{-.5f, .5f}(gamma_host);
     ck_tile::FillUniformDistribution<BetaDataType>{-.5f, .5f}(beta_host);
@@ -63,6 +69,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
     ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes());
     ck_tile::DeviceMem beta_buf(beta_host.get_element_space_size_in_bytes());
     ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes());
+#ifdef SAVE_MEAN_INV_STD
+    ck_tile::DeviceMem mean_buf(mean_host_dev.get_element_space_size_in_bytes());
+    ck_tile::DeviceMem invStd_buf(invStd_host_dev.get_element_space_size_in_bytes());
+#endif

     x_buf.ToDevice(x_host.data());
     gamma_buf.ToDevice(gamma_host.data());
@@ -74,6 +84,13 @@ bool run(const ck_tile::ArgParser& arg_parser)
                                  gamma_buf.GetDeviceBuffer(),
                                  beta_buf.GetDeviceBuffer(),
                                  y_buf.GetDeviceBuffer(),
+#ifdef SAVE_MEAN_INV_STD
+                                 mean_buf.GetDeviceBuffer(),
+                                 invStd_buf.GetDeviceBuffer(),
+#else
+                                 nullptr,
+                                 nullptr,
+#endif
                                  epsilon,
                                  M,
                                  N};
@@ -121,6 +138,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
         }
         std::cout << std::endl << std::flush;
     }
+    std::cout << "pass = " << pass << std::endl;

     return pass;
 }
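With SAVE_MEAN_INV_STD defined, the example now round-trips mean and inverse standard deviation through mean_buf/invStd_buf. A plain-C++ sketch of a host reference those outputs could be checked against (straight loops over a row-major M×N buffer; this is not the ck_tile reference implementation):

#include <cmath>
#include <cstddef>
#include <vector>

// Reference per-row mean and inverse standard deviation for an M x N,
// row-major input, matching inv_std = 1 / sqrt(var + epsilon).
void reference_mean_inv_std(const std::vector<float>& x, int M, int N, float epsilon,
                            std::vector<float>& mean, std::vector<float>& inv_std)
{
    mean.assign(M, 0.f);
    inv_std.assign(M, 0.f);
    for(int m = 0; m < M; ++m)
    {
        float sum = 0.f, sq_sum = 0.f;
        for(int n = 0; n < N; ++n)
        {
            float v = x[static_cast<std::size_t>(m) * N + n];
            sum += v;
            sq_sum += v * v;
        }
        mean[m]    = sum / N;
        float var  = sq_sum / N - mean[m] * mean[m]; // population variance
        inv_std[m] = 1.f / std::sqrt(var + epsilon);
    }
}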
example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp

@@ -56,8 +56,8 @@ struct layernorm2d_fwd_args
     const void* p_gamma;
     const void* p_beta;
     void* p_y;
-    // void* p_mean;
-    // void* p_invStd;
+    void* p_mean;
+    void* p_invStd;
     float epsilon;
     ck_tile::index_t M;
     ck_tile::index_t N;
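Callers now populate p_mean and p_invStd directly, or pass nullptr as the example above does when SAVE_MEAN_INV_STD is off. A minimal sketch of filling the struct (the include path follows this repository's layout; the struct's input-pointer member, not shown in this hunk, is omitted here):

#include "layernorm2d_fwd.hpp" // declares layernorm2d_fwd_args (path per this repo)

// Hedged sketch: fill the argument struct from raw device pointers.
// The d_* pointers are assumed to be valid device allocations.
layernorm2d_fwd_args make_args(const void* d_gamma, const void* d_beta, void* d_y,
                               void* d_mean, void* d_inv_std, int M, int N)
{
    layernorm2d_fwd_args a{};
    a.p_gamma  = d_gamma;
    a.p_beta   = d_beta;
    a.p_y      = d_y;
    a.p_mean   = d_mean;    // may be nullptr when the mean is not saved
    a.p_invStd = d_inv_std; // may be nullptr when the inv-std is not saved
    a.epsilon  = 1e-5f;     // illustrative value
    a.M        = M;
    a.N        = N;
    return a;
}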
example/ck_tile/02_layernorm2d/layernorm_dispatch.hpp

@@ -59,6 +59,8 @@ struct layernorm_dispatch
                                  param.p_gamma,
                                  param.p_beta,
                                  param.p_y,
+                                 param.p_mean,
+                                 param.p_invStd,
                                  param.epsilon,
                                  param.M,
                                  param.N));
include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp

@@ -31,10 +31,15 @@ struct Layernorm2dFwd
    static constexpr ck_tile::index_t kMPerBlock = Problem::BlockShape::kMPerBlock;
    static constexpr ck_tile::index_t kNPerBlock = Problem::BlockShape::kNPerBlock;
+   static constexpr bool kPadM    = false; // TODO - Problem::kPadM
    static constexpr bool kPadN    = Problem::kPadN;
    static constexpr bool kTwoPass = Problem::kTwoPass;

    static constexpr ck_tile::index_t kNThreadPerWarp = Problem::BlockShape::kNThreadPerWarp;
+   static constexpr ck_tile::index_t kNPerThread     = Problem::BlockShape::kNPerThread;
+
+   static constexpr auto I0 = number<0>{};
+   static constexpr auto I1 = number<1>{};

    struct Kargs
    {
@@ -43,8 +48,8 @@ struct Layernorm2dFwd
        const void* p_beta;
        void* p_y;
-       // void* p_mean;
-       // void* p_invStd;
+       void* p_mean;
+       void* p_invStd;
        float epsilon;
@@ -150,53 +155,24 @@ struct Layernorm2dFwd
        return iN0 * S::kNPerThread + iN3;
    }

-   template <bool Cond = (kHasGamma && kHasBeta)>
-   CK_TILE_DEVICE std::enable_if_t<Cond> OnePassLayernorm2dFwd(const XDataType* p_x,
-                                                               const GammaDataType* p_gamma,
-                                                               const BetaDataType* p_beta,
-                                                               YDataType* p_y,
-                                                               const ComputeDataType epsilon,
-                                                               ck_tile::index_t M,
-                                                               ck_tile::index_t N) const
+   template <typename XBlockWindow,
+             typename GammaBlockWindow,
+             typename BetaBlockWindow,
+             typename YBlockWindow,
+             typename MeanBlockWindow,
+             typename InvStdBlockWindow,
+             bool Cond = (kHasGamma && kHasBeta)>
+   CK_TILE_DEVICE std::enable_if_t<Cond> OnePassLayernorm2dFwd(XBlockWindow& x_block_window,
+                                                               GammaBlockWindow& gamma_block_window,
+                                                               BetaBlockWindow& beta_block_window,
+                                                               YBlockWindow& y_block_window,
+                                                               MeanBlockWindow& mean_block_window,
+                                                               InvStdBlockWindow& inv_std_block_window,
+                                                               ComputeDataType epsilon,
+                                                               ck_tile::index_t N) const
    {
-       using S = typename Problem::BlockShape;
-       constexpr auto I0 = number<0>{};
-       constexpr auto I1 = number<1>{};
-
-       const auto x_m_n = [&]() {
-           const auto x_dram_naive = make_naive_tensor_view<address_space_enum::global>(
-               p_x, make_tuple(M, N), make_tuple(N, 1), number<S::kNPerThread>{}, number<1>{});
-           return pad_tensor_view(
-               x_dram_naive,
-               make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
-               sequence<false, kPadN>{});
-       }();
-
-       const auto gamma_n = [&]() {
-           const auto gamma_dram_naive = make_naive_tensor_view<address_space_enum::global>(
-               p_gamma, make_tuple(N), make_tuple(1), number<S::kNPerThread>{}, number<1>{});
-           return pad_tensor_view(
-               gamma_dram_naive, make_tuple(number<kNPerBlock>{}), sequence<kPadN>{});
-       }();
-
-       const auto beta_n = [&]() {
-           const auto gamma_dram_naive = make_naive_tensor_view<address_space_enum::global>(
-               p_beta, make_tuple(N), make_tuple(1), number<S::kNPerThread>{}, number<1>{});
-           return pad_tensor_view(
-               gamma_dram_naive, make_tuple(number<kNPerBlock>{}), sequence<kPadN>{});
-       }();
-
-       const auto iM = get_block_id() * kMPerBlock;
-
-       constexpr auto xDstr = MakeXBlockTileDistribution();
-
-       auto x_block_window = make_tile_window(
-           x_m_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {iM, 0}, xDstr);
-
        auto intra_thread_count_last = GetLastloopLayerNormIntraLaneReduceCount(N);
        ThreadWelford<ComputeDataType, XDataType> thread_welford{intra_thread_count_last};

        using XTensorType = decltype(load_tile(x_block_window));
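The window-construction boilerplate deleted here is ck_tile's usual three-step recipe. Annotated for reference, using the same calls as the removed lines (this fragment assumes the surrounding definitions in this file, e.g. S, kPadN, iM, and xDstr):

// 1) Wrap the raw pointer in a naive global-memory view with shape (M, N),
//    row-major strides (N, 1), and a vector-access hint of S::kNPerThread.
const auto x_dram_naive = make_naive_tensor_view<address_space_enum::global>(
    p_x, make_tuple(M, N), make_tuple(N, 1), number<S::kNPerThread>{}, number<1>{});

// 2) Pad the view to whole block tiles so edge tiles along N are safe to touch.
const auto x_m_n = pad_tensor_view(
    x_dram_naive,
    make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
    sequence<false, kPadN>{});

// 3) Open this block's tile window at row offset iM with distribution xDstr.
auto x_block_window = make_tile_window(
    x_m_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {iM, 0}, xDstr);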
@@ -210,37 +186,21 @@ struct Layernorm2dFwd
            const auto x_block_tensor = load_tile(x_block_window);

            thread_welford(x_block_tensor, mean_compute_block_tensor, var_compute_block_tensor);

-       constexpr auto gammaDstr = MakeGammaBetaBlockTileDistribution();
-       constexpr auto betaDstr  = gammaDstr;
-
-       auto gamma_block_window =
-           make_tile_window(gamma_n, make_tuple(number<kNPerBlock>{}), {0}, gammaDstr);
-       auto beta_block_window =
-           make_tile_window(beta_n, make_tuple(number<kNPerBlock>{}), {0}, betaDstr);
-
-       const auto gamma_block_tensor = load_tile(gamma_block_window);
-       const auto beta_block_tensor  = load_tile(beta_block_window);
-
        // TODO: support cross warp Welford
        WarpMergeWelford<ComputeDataType, true>{}(
            mean_compute_block_tensor, var_compute_block_tensor, thread_welford.cur_count_);

        auto inv_std_compute_block_tensor = InvSqrt(var_compute_block_tensor, epsilon);

-       // TODO: Extract normalize pipeline
-       const auto y_m_n = [&]() {
-           const auto y_dram_naive = make_naive_tensor_view<address_space_enum::global>(
-               p_y, make_tuple(M, N), make_tuple(N, 1), number<S::kNPerThread>{}, number<1>{});
-           return pad_tensor_view(
-               y_dram_naive,
-               make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
-               sequence<false, kPadN>{});
-       }();
-
-       auto y_block_window = make_tile_window(
-           y_m_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {iM, 0});
+       if constexpr(kSaveMean)
+           store_tile(mean_block_window, cast_tile<MeanDataType>(mean_compute_block_tensor));
+       if constexpr(kSaveInvStd)
+           store_tile(inv_std_block_window, cast_tile<InvStdDataType>(inv_std_compute_block_tensor));
+
+       // TODO: Extract normalize pipeline
+       const auto gamma_block_tensor = load_tile(gamma_block_window);
+       const auto beta_block_tensor  = load_tile(beta_block_window);

        constexpr auto x_spans = decltype(x_block_tensor)::get_distributed_spans();
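The statistics pipeline that stays behind here — ThreadWelford, then WarpMergeWelford, then InvSqrt — is ordinary Welford accumulation with a pairwise merge, followed by inv_std = 1/sqrt(var + epsilon). A scalar model, independent of ck_tile (names below are illustrative, not the ck_tile implementation):

#include <cmath>
#include <cstdio>

// Scalar model of the block-level statistics pipeline above.
struct Welford
{
    float mean = 0.f, m2 = 0.f;
    int count = 0;

    void update(float x) // role of ThreadWelford: per-thread accumulation
    {
        ++count;
        float delta = x - mean;
        mean += delta / count;
        m2 += delta * (x - mean);
    }

    void merge(const Welford& o) // role of WarpMergeWelford: combine partial stats
    {
        if(o.count == 0)
            return;
        int n       = count + o.count;
        float delta = o.mean - mean;
        mean += delta * o.count / n;
        m2 += o.m2 + delta * delta * (static_cast<float>(count) * o.count / n);
        count = n;
    }

    float variance() const { return count ? m2 / count : 0.f; }
};

float inv_sqrt_var(float var, float eps) { return 1.f / std::sqrt(var + eps); } // role of InvSqrt

int main()
{
    Welford a, b;
    a.update(1.f); a.update(2.f); // "lane 0"
    b.update(3.f); b.update(4.f); // "lane 1"
    a.merge(b);                   // mean = 2.5, variance = 1.25
    std::printf("mean=%f inv_std=%f\n", a.mean, inv_sqrt_var(a.variance(), 1e-5f));
}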
@@ -269,51 +229,24 @@ struct Layernorm2dFwd
            store_tile(y_block_window, y_block_tensor);
    }

-   template <bool Cond = (kHasGamma && kHasBeta)>
-   CK_TILE_DEVICE std::enable_if_t<Cond> TwoPassLayernorm2dFwd(const XDataType* p_x,
-                                                               const GammaDataType* p_gamma,
-                                                               const BetaDataType* p_beta,
-                                                               YDataType* p_y,
-                                                               const ComputeDataType epsilon,
-                                                               ck_tile::index_t M,
-                                                               ck_tile::index_t N) const
+   template <typename XBlockWindow,
+             typename GammaBlockWindow,
+             typename BetaBlockWindow,
+             typename YBlockWindow,
+             typename MeanBlockWindow,
+             typename InvStdBlockWindow,
+             bool Cond = (kHasGamma && kHasBeta)>
+   CK_TILE_DEVICE std::enable_if_t<Cond> TwoPassLayernorm2dFwd(XBlockWindow& x_block_window,
+                                                               GammaBlockWindow& gamma_block_window,
+                                                               BetaBlockWindow& beta_block_window,
+                                                               YBlockWindow& y_block_window,
+                                                               MeanBlockWindow& mean_block_window,
+                                                               InvStdBlockWindow& inv_std_block_window,
+                                                               ComputeDataType epsilon,
+                                                               ck_tile::index_t N) const
    {
        using S = typename Problem::BlockShape;
-       constexpr auto I0 = number<0>{};
-       constexpr auto I1 = number<1>{};
-
-       const auto x_m_n = [&]() {
-           const auto x_dram_naive = make_naive_tensor_view<address_space_enum::global>(
-               p_x, make_tuple(M, N), make_tuple(N, 1), number<S::kNPerThread>{}, number<1>{});
-           return pad_tensor_view(
-               x_dram_naive,
-               make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
-               sequence<false, true>{});
-       }();
-
-       const auto gamma_n = [&]() {
-           const auto gamma_dram_naive = make_naive_tensor_view<address_space_enum::global>(
-               p_gamma, make_tuple(N), make_tuple(1), number<S::kNPerThread>{}, number<1>{});
-           return pad_tensor_view(
-               gamma_dram_naive, make_tuple(number<kNPerBlock>{}), sequence<true>{});
-       }();
-
-       const auto beta_n = [&]() {
-           const auto gamma_dram_naive = make_naive_tensor_view<address_space_enum::global>(
-               p_beta, make_tuple(N), make_tuple(1), number<S::kNPerThread>{}, number<1>{});
-           return pad_tensor_view(
-               gamma_dram_naive, make_tuple(number<kNPerBlock>{}), sequence<true>{});
-       }();
-
-       const auto iM = get_block_id() * kMPerBlock;
-
-       constexpr auto xDstr = MakeXBlockTileDistribution();
-
-       auto x_block_window = make_tile_window(
-           x_m_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {iM, 0}, xDstr);
-
        index_t num_n_tile_iteration =
            __builtin_amdgcn_readfirstlane((N + kNPerBlock - 1) / kNPerBlock);
@@ -352,27 +285,11 @@ struct Layernorm2dFwd
        auto inv_std_compute_block_tensor = InvSqrt(var_compute_block_tensor, epsilon);

-       // TODO: Extract normalize pipeline
-       const auto y_m_n = [&]() {
-           const auto y_dram_naive = make_naive_tensor_view<address_space_enum::global>(
-               p_y, make_tuple(M, N), make_tuple(N, 1), number<S::kNPerThread>{}, number<1>{});
-           return pad_tensor_view(
-               y_dram_naive,
-               make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
-               sequence<false, true>{});
-       }();
-
-       auto y_block_window = make_tile_window(
-           y_m_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {iM, 0});
-
-       constexpr auto gammaDstr = MakeGammaBetaBlockTileDistribution();
-       constexpr auto betaDstr  = gammaDstr;
-
-       auto gamma_block_window =
-           make_tile_window(gamma_n, make_tuple(number<kNPerBlock>{}), {0}, gammaDstr);
-       auto beta_block_window =
-           make_tile_window(beta_n, make_tuple(number<kNPerBlock>{}), {0}, betaDstr);
+       if constexpr(kSaveMean)
+           store_tile(mean_block_window, cast_tile<MeanDataType>(mean_compute_block_tensor));
+       if constexpr(kSaveInvStd)
+           store_tile(inv_std_block_window, cast_tile<InvStdDataType>(inv_std_compute_block_tensor));

        // reverse read x to reuse cache
        ck_tile::index_t stride_to_right_most_window =
@@ -426,29 +343,137 @@ struct Layernorm2dFwd
                                      const void* p_gamma,
                                      const void* p_beta,
                                      void* p_y,
+                                     void* p_mean,
+                                     void* p_invStd,
                                      const ComputeDataType epsilon,
                                      ck_tile::index_t M,
                                      ck_tile::index_t N) const
    {
+       const auto x_m_n = [&]() {
+           const auto x_dram_naive = make_naive_tensor_view<address_space_enum::global>(
+               static_cast<const XDataType*>(p_x),
+               make_tuple(M, N),
+               make_tuple(N, 1),
+               number<kNPerThread>{},
+               number<1>{});
+           return pad_tensor_view(
+               x_dram_naive,
+               make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
+               sequence<kPadM, kPadN>{});
+       }();
+
+       const auto gamma_n = [&]() {
+           const auto gamma_dram_naive = make_naive_tensor_view<address_space_enum::global>(
+               static_cast<const GammaDataType*>(p_gamma),
+               make_tuple(N),
+               make_tuple(1),
+               number<kNPerThread>{},
+               number<1>{});
+           return pad_tensor_view(
+               gamma_dram_naive, make_tuple(number<kNPerBlock>{}), sequence<kPadN>{});
+       }();
+
+       const auto beta_n = [&]() {
+           const auto gamma_dram_naive = make_naive_tensor_view<address_space_enum::global>(
+               static_cast<const BetaDataType*>(p_beta),
+               make_tuple(N),
+               make_tuple(1),
+               number<kNPerThread>{},
+               number<1>{});
+           return pad_tensor_view(
+               gamma_dram_naive, make_tuple(number<kNPerBlock>{}), sequence<kPadN>{});
+       }();
+
+       const auto iM = get_block_id() * kMPerBlock;
+
+       constexpr auto xDstr = MakeXBlockTileDistribution();
+
+       auto x_block_window = make_tile_window(
+           x_m_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {iM, 0}, xDstr);
+
+       const auto y_m_n = [&]() {
+           const auto y_dram_naive = make_naive_tensor_view<address_space_enum::global>(
+               static_cast<YDataType*>(p_y),
+               make_tuple(M, N),
+               make_tuple(N, 1),
+               number<kNPerThread>{},
+               number<1>{});
+           return pad_tensor_view(
+               y_dram_naive,
+               make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
+               sequence<kPadM, kPadN>{});
+       }();
+
+       auto y_block_window = make_tile_window(
+           y_m_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {iM, 0});
+
+       constexpr auto gammaDstr = MakeGammaBetaBlockTileDistribution();
+       constexpr auto betaDstr  = gammaDstr;
+
+       auto gamma_block_window =
+           make_tile_window(gamma_n, make_tuple(number<kNPerBlock>{}), {0}, gammaDstr);
+       auto beta_block_window = make_tile_window(
+           beta_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {0}, betaDstr);
+
+       auto mean_block_window = [&]() {
+           if constexpr(kSaveMean)
+           {
+               const auto mean_m = [&]() {
+                   const auto mean_dram_naive =
+                       make_naive_tensor_view_packed<address_space_enum::global>(
+                           static_cast<MeanDataType*>(p_mean), make_tuple(M), number<1>{});
+                   return pad_tensor_view(
+                       mean_dram_naive, make_tuple(number<kMPerBlock>{}), sequence<kPadM>{});
+               }();
+               return make_tile_window(mean_m, make_tuple(number<kMPerBlock>{}), {iM});
+           }
+           else
+               return make_null_tile_window(make_tuple(number<kMPerBlock>{}));
+       }();
+
+       auto inv_std_block_window = [&]() {
+           if constexpr(kSaveInvStd)
+           {
+               const auto inv_std_m = [&]() {
+                   const auto inv_std_dram_naive =
+                       make_naive_tensor_view_packed<address_space_enum::global>(
+                           static_cast<InvStdDataType*>(p_invStd), make_tuple(M), number<1>{});
+                   return pad_tensor_view(
+                       inv_std_dram_naive, make_tuple(number<kMPerBlock>{}), sequence<kPadM>{});
+               }();
+               return make_tile_window(inv_std_m, make_tuple(number<kMPerBlock>{}), {iM});
+           }
+           else
+               return make_null_tile_window(make_tuple(number<kMPerBlock>{}));
+       }();
+
        if constexpr(kTwoPass)
        {
-           TwoPassLayernorm2dFwd(static_cast<const XDataType*>(p_x),
-                                 static_cast<const GammaDataType*>(p_gamma),
-                                 static_cast<const BetaDataType*>(p_beta),
-                                 static_cast<YDataType*>(p_y),
+           TwoPassLayernorm2dFwd(x_block_window,
+                                 gamma_block_window,
+                                 beta_block_window,
+                                 y_block_window,
+                                 mean_block_window,
+                                 inv_std_block_window,
                                  static_cast<const ComputeDataType>(epsilon),
-                                 M,
                                  N);
        }
        else
        {
-           OnePassLayernorm2dFwd(static_cast<const XDataType*>(p_x),
-                                 static_cast<const GammaDataType*>(p_gamma),
-                                 static_cast<const BetaDataType*>(p_beta),
-                                 static_cast<YDataType*>(p_y),
+           OnePassLayernorm2dFwd(x_block_window,
+                                 gamma_block_window,
+                                 beta_block_window,
+                                 y_block_window,
+                                 mean_block_window,
+                                 inv_std_block_window,
                                  static_cast<const ComputeDataType>(epsilon),
-                                 M,
                                  N);
        }
    }
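One design point worth noting in the new operator(): when kSaveMean or kSaveInvStd is off, the corresponding window comes from make_null_tile_window(...), so the store_tile calls guarded by if constexpr in the pass bodies resolve against a do-nothing window. A minimal stand-alone sketch of that dispatch shape (C++17; RealWindow/NullWindow are illustrative, not the ck_tile types):

#include <iostream>

// Illustrative stand-ins: a real window that stores, and a null window that ignores stores.
struct RealWindow { void store(float v) const { std::cout << "stored " << v << '\n'; } };
struct NullWindow { void store(float) const {} }; // no-op, like a null tile window

template <bool kSaveMean>
auto make_mean_window()
{
    // Under if constexpr, the two branches may return different types;
    // auto deduces from the branch that is actually instantiated.
    if constexpr(kSaveMean)
        return RealWindow{};
    else
        return NullWindow{};
}

int main()
{
    auto w_on  = make_mean_window<true>();
    auto w_off = make_mean_window<false>();
    w_on.store(0.5f);  // writes
    w_off.store(0.5f); // compiles to a no-op; nothing is stored
}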