gaoqiong / composable_kernel_ROCM / Commits

Commit c2ea75ed, authored Jan 21, 2025 by Jiming Ruan
Add no variance support to block_norm_reduce
parent b16fad32

Showing 4 changed files with 227 additions and 77 deletions (+227 / -77)
include/ck_tile/core/numeric/vector_type.hpp  (+15, -1)
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp  (+5, -1)
include/ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp  (+199, -69)
include/ck_tile/ops/norm_reduce/block/block_norm_reduce_problem.hpp  (+8, -6)
include/ck_tile/core/numeric/vector_type.hpp
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once

@@ -78,11 +78,13 @@ using has_same_scalar_type = std::is_same<typename vector_traits<remove_cvref_t<
 // attention! 2 vector type could be just the same type
 // fp64
 using fp64_t   = double;
+using fp64x1_t = double __attribute__((ext_vector_type(1)));
 using fp64x2_t = double __attribute__((ext_vector_type(2)));
 using fp64x4_t = double __attribute__((ext_vector_type(4)));
 
 // fp32
 using fp32_t   = float;
+using fp32x1_t = float __attribute__((ext_vector_type(1)));
 using fp32x2_t = float __attribute__((ext_vector_type(2)));
 using fp32x4_t = float __attribute__((ext_vector_type(4)));
 using fp32x8_t = float __attribute__((ext_vector_type(8)));

@@ -92,6 +94,7 @@ using fp32x64_t = float __attribute__((ext_vector_type(64)));
 // fp16
 // using fp16_t = ...
+using fp16x1_t = _Float16 __attribute__((ext_vector_type(1)));
 using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
 using fp16x4_t = _Float16 __attribute__((ext_vector_type(4)));
 using fp16x8_t = _Float16 __attribute__((ext_vector_type(8)));

@@ -101,6 +104,7 @@ using fp16x64_t = _Float16 __attribute__((ext_vector_type(64)));
 // bf16
 // using bf16_t = ...
+using bf16x1_t = bf16_raw_t __attribute__((ext_vector_type(1)));
 using bf16x2_t = bf16_raw_t __attribute__((ext_vector_type(2)));
 using bf16x4_t = bf16_raw_t __attribute__((ext_vector_type(4)));
 using bf16x8_t = bf16_raw_t __attribute__((ext_vector_type(8)));

@@ -110,6 +114,7 @@ using bf16x64_t = bf16_raw_t __attribute__((ext_vector_type(64)));
 // i32
 // using int32_t = ...
+using int32x1_t = int32_t __attribute__((ext_vector_type(1)));
 using int32x2_t = int32_t __attribute__((ext_vector_type(2)));
 using int32x4_t = int32_t __attribute__((ext_vector_type(4)));
 using int32x8_t = int32_t __attribute__((ext_vector_type(8)));

@@ -119,6 +124,7 @@ using int32x64_t = int32_t __attribute__((ext_vector_type(64)));
 // u32
 // using uint32_t = ...
+using uint32x1_t = uint32_t __attribute__((ext_vector_type(1)));
 using uint32x2_t = uint32_t __attribute__((ext_vector_type(2)));
 using uint32x4_t = uint32_t __attribute__((ext_vector_type(4)));
 using uint32x8_t = uint32_t __attribute__((ext_vector_type(8)));

@@ -128,6 +134,7 @@ using uint32x64_t = uint32_t __attribute__((ext_vector_type(64)));
 // i16
 // using int16_t = ...
+using int16x1_t = int16_t __attribute__((ext_vector_type(1)));
 using int16x2_t = int16_t __attribute__((ext_vector_type(2)));
 using int16x4_t = int16_t __attribute__((ext_vector_type(4)));
 using int16x8_t = int16_t __attribute__((ext_vector_type(8)));

@@ -137,6 +144,7 @@ using int16x64_t = int16_t __attribute__((ext_vector_type(64)));
 // u16
 // using uint16_t
+using uint16x1_t = uint16_t __attribute__((ext_vector_type(1)));
 using uint16x2_t = uint16_t __attribute__((ext_vector_type(2)));
 using uint16x4_t = uint16_t __attribute__((ext_vector_type(4)));
 using uint16x8_t = uint16_t __attribute__((ext_vector_type(8)));

@@ -146,6 +154,7 @@ using uint16x64_t = uint16_t __attribute__((ext_vector_type(64)));
 // i8
 // using int8_t
+using int8x1_t = int8_t __attribute((ext_vector_type(1)));
 using int8x2_t = int8_t __attribute((ext_vector_type(2)));
 using int8x4_t = int8_t __attribute((ext_vector_type(4)));
 using int8x8_t = int8_t __attribute((ext_vector_type(8)));

@@ -155,6 +164,7 @@ using int8x64_t = int8_t __attribute((ext_vector_type(64)));
 // ui8
 // using uint8_t
+using uint8x1_t = uint8_t __attribute((ext_vector_type(1)));
 using uint8x2_t = uint8_t __attribute((ext_vector_type(2)));
 using uint8x4_t = uint8_t __attribute((ext_vector_type(4)));
 using uint8x8_t = uint8_t __attribute((ext_vector_type(8)));

@@ -165,6 +175,7 @@ using uint8x64_t = uint8_t __attribute((ext_vector_type(64)));
 #if CK_TILE_USE_CUSTOM_DATA_TYPE
 // f8
 // using fp8_t
+using fp8x1_t = fp8_raw_t __attribute((ext_vector_type(1)));
 using fp8x2_t = fp8_raw_t __attribute((ext_vector_type(2)));
 using fp8x4_t = fp8_raw_t __attribute((ext_vector_type(4)));
 using fp8x8_t = fp8_raw_t __attribute((ext_vector_type(8)));

@@ -174,6 +185,7 @@ using fp8x64_t = fp8_raw_t __attribute((ext_vector_type(64)));
 // bf8
 // using bf8_t
+using bf8x1_t = bf8_raw_t __attribute((ext_vector_type(1)));
 using bf8x2_t = bf8_raw_t __attribute((ext_vector_type(2)));
 using bf8x4_t = bf8_raw_t __attribute((ext_vector_type(4)));
 using bf8x8_t = bf8_raw_t __attribute((ext_vector_type(8)));

@@ -183,6 +195,7 @@ using bf8x64_t = bf8_raw_t __attribute((ext_vector_type(64)));
 #else
 // f8
 // using fp8_t
+using fp8x1_t = fp8_t __attribute((ext_vector_type(1)));
 using fp8x2_t = fp8_t __attribute((ext_vector_type(2)));
 using fp8x4_t = fp8_t __attribute((ext_vector_type(4)));
 using fp8x8_t = fp8_t __attribute((ext_vector_type(8)));

@@ -192,6 +205,7 @@ using fp8x64_t = fp8_t __attribute((ext_vector_type(64)));
 // bf8
 // using bf8_t
+using bf8x1_t = bf8_t __attribute((ext_vector_type(1)));
 using bf8x2_t = bf8_t __attribute((ext_vector_type(2)));
 using bf8x4_t = bf8_t __attribute((ext_vector_type(4)));
 using bf8x8_t = bf8_t __attribute((ext_vector_type(8)));
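The only functional addition in this file is the family of single-element aliases (fp64x1_t, fp32x1_t, ..., bf8x1_t); fp32x1_t in particular is what block_norm_reduce.hpp below selects as its shared-memory pack type when neither variance nor a Welford count has to be exchanged. A minimal standalone sketch (my own, not part of the commit) of why a width-1 Clang ext_vector_type is convenient: it keeps the same operator[] packing syntax as the wider types, so pack/unpack code needs no scalar special case. The alias names below are redefined locally so the snippet compiles on its own with clang.

// ext_vector_type is a Clang extension; these mirror the new fp32x1_t / existing fp32x2_t.
using my_fp32x1_t = float __attribute__((ext_vector_type(1)));
using my_fp32x2_t = float __attribute__((ext_vector_type(2)));

template <typename Vec>
Vec pack_first(float mean)
{
    Vec v{};
    v[0] = mean; // identical subscript syntax for 1-wide and 2-wide vectors
    return v;
}

int main()
{
    auto a = pack_first<my_fp32x1_t>(1.0f); // one dword of payload: mean only
    auto b = pack_first<my_fp32x2_t>(2.0f); // two dwords: e.g. mean + count
    return static_cast<int>(a[0]) + static_cast<int>(b[0]) - 3; // returns 0
}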
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once

@@ -48,6 +48,7 @@ struct Layernorm2dFwdPipelineDefaultPolicy
         using P_ = BlockNormReduceProblem<typename Problem::ComputeDataType,
                                           typename Problem::ComputeDataType,
                                           typename Problem::BlockShape,
+                                          true,
                                           Problem::Traits::kFastFDiv,
                                           Problem::Traits::kWelford>;
 
         return BlockNormReduce<P_>{};

@@ -59,6 +60,7 @@ struct Layernorm2dFwdPipelineDefaultPolicy
         using P_ = BlockNormReduceProblem<typename Problem::ComputeDataType,
                                           typename Problem::ComputeDataType,
                                           typename Problem::BlockShape,
+                                          true,
                                           Problem::Traits::kFastFDiv,
                                           Problem::Traits::kWelford>;

@@ -71,6 +73,7 @@ struct Layernorm2dFwdPipelineDefaultPolicy
         using P_ = BlockNormReduceProblem<typename Problem::ComputeDataType,
                                           typename Problem::ComputeDataType,
                                           typename Problem::BlockShape,
+                                          true,
                                           Problem::Traits::kFastFDiv,
                                           Problem::Traits::kWelford>;

@@ -85,6 +88,7 @@ struct Layernorm2dFwdPipelineDefaultPolicy
         using P_ = BlockNormReduceProblem<typename Problem::ComputeDataType,
                                           typename Problem::ComputeDataType,
                                           typename Problem::BlockShape,
+                                          true,
                                           Problem::Traits::kFastFDiv,
                                           Problem::Traits::kWelford>;
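In all four hunks the inserted literal `true` lands on the new kComputeVariance_ template parameter (its position is visible in block_norm_reduce_problem.hpp at the bottom of this commit), so the layernorm pipeline keeps requesting both mean and variance; a mean-only caller would pass `false` there instead. A compact standalone sketch of that flag-threading pattern, using mock types rather than the real ck_tile ones:

#include <type_traits>

// Mock of the reduce-level problem: the fourth argument is the new compile-time switch.
template <typename XDataType_, typename ComputeDataType_, typename BlockShape_,
          bool kComputeVariance_, bool kFastFDiv_, bool kWelford_>
struct MockBlockNormReduceProblem
{
    static constexpr bool kComputeVariance = kComputeVariance_;
};

struct MockBlockShape {};

// What the layernorm policy does: hard-code `true` because layernorm needs the variance.
using LayernormReduceProblem =
    MockBlockNormReduceProblem<float, float, MockBlockShape, /*kComputeVariance=*/true, false, true>;

// A hypothetical mean-only pipeline: same problem type, variance switched off at compile time.
using MeanOnlyReduceProblem =
    MockBlockNormReduceProblem<float, float, MockBlockShape, /*kComputeVariance=*/false, false, true>;

static_assert(LayernormReduceProblem::kComputeVariance, "layernorm keeps the variance path");
static_assert(!MeanOnlyReduceProblem::kComputeVariance, "mean-only build drops it");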
include/ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once

@@ -14,11 +14,11 @@ struct BlockNormReduce
     using Problem         = remove_cvref_t<Problem_>;
     using XDataType       = typename Problem::XDataType;
     using ComputeDataType = typename Problem::ComputeDataType;
+    static constexpr bool kComputeVariance = Problem::kComputeVariance;
     static constexpr bool kFastFDiv = Problem::kFastFDiv;
     static constexpr bool kWelford  = Problem::kWelford;
 
-    CK_TILE_DEVICE constexpr BlockNormReduce() {}
-
+    private:
     // [CAUSION] - max_count_ is to deal with the padding problem
     // max_count_ is depend on caller, eg: naive and splitN norm_reduce will have different
     // calculation of max_count_

@@ -26,7 +26,7 @@ struct BlockNormReduce
     template <typename XDistributedTensor_,
               typename MeanDistributedTensor_,
               typename VarDistributedTensor_>
-    CK_TILE_DEVICE void operator()(const XDistributedTensor_& x_tensor,
+    CK_TILE_DEVICE void Impl(const XDistributedTensor_& x_tensor,
                              MeanDistributedTensor_& mean_tensor,
                              VarDistributedTensor_& var_tensor,
                              int& cur_count_, // -> prefer init as zero

@@ -36,6 +36,8 @@ struct BlockNormReduce
         constexpr auto I1 = number<1>{};
 
         constexpr auto spans = XDistributedTensor_::get_distributed_spans();
+        constexpr bool computeVariance =
+            !std::is_same<VarDistributedTensor_, null_tensor>::value && kComputeVariance;
         sweep_tile_span(spans[I1], [&](auto dstr_idx_i1) {
             if(cur_count_ < max_count_)

@@ -47,6 +49,8 @@ struct BlockNormReduce
                     auto x = ck_tile::type_convert<ComputeDataType>(x_tensor[in_dstr_idx]);
                     if(kWelford)
                     {
+                        if constexpr(computeVariance)
+                        {
                             welford_update(mean_tensor(out_dstr_idx),
                                            var_tensor(out_dstr_idx),

@@ -55,15 +59,48 @@ struct BlockNormReduce
                                            constant<kFastFDiv>{});
+                        }
+                        else
+                        {
+                            welford_update(mean_tensor(out_dstr_idx),
+                                           x,
+                                           cur_count_,
+                                           constant<kFastFDiv>{});
+                        }
                     }
                     else
                     {
                         mean_tensor(out_dstr_idx) += x;
+                        if constexpr(computeVariance)
+                        {
                             var_tensor(out_dstr_idx) += x * x;
+                        }
                     }
                 });
             }
         });
     }
 
+    public:
+    CK_TILE_DEVICE constexpr BlockNormReduce() {}
+
+    template <typename XDistributedTensor_,
+              typename MeanDistributedTensor_,
+              typename VarDistributedTensor_>
+    CK_TILE_DEVICE void operator()(const XDistributedTensor_& x_tensor,
+                                   MeanDistributedTensor_& mean_tensor,
+                                   VarDistributedTensor_& var_tensor,
+                                   int& cur_count_,
+                                   const int& max_count_)
+    {
+        Impl(x_tensor, mean_tensor, var_tensor, cur_count_, max_count_);
+    }
+
+    template <typename XDistributedTensor_, typename MeanDistributedTensor_>
+    CK_TILE_DEVICE void operator()(const XDistributedTensor_& x_tensor,
+                                   MeanDistributedTensor_& mean_tensor,
+                                   int& cur_count_,
+                                   const int& max_count_)
+    {
+        Impl(x_tensor, mean_tensor, null_tensor{}, cur_count_, max_count_);
+    }
+
     template <typename XDistributedTensor_>
     CK_TILE_DEVICE static auto MakeMeanVarBlockTile()
     {

@@ -82,12 +119,13 @@ struct BlockNormReduce
         return tensor;
     }
 
-    template <typename XDistributedTensor_>
+    template <bool kComputeVariance, typename XDistributedTensor_>
     CK_TILE_DEVICE auto
     operator()(const XDistributedTensor_& x_tensor, int& cur_count_, const int& max_count_)
     {
         auto mean_tensor = MakeMeanVarBlockTile<XDistributedTensor_>();
-        auto var_tensor  = MakeMeanVarBlockTile<XDistributedTensor_>();
+        auto var_tensor =
+            kComputeVariance ? MakeMeanVarBlockTile<XDistributedTensor_>() : null_tensor{};
         clear_tile(mean_tensor);
         clear_tile(var_tensor);
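The restructuring above is the heart of the commit: the old public operator() becomes a private Impl(), two thin public overloads forward to it, and the mean-only overload passes a null_tensor so every var_tensor access can be discarded with `if constexpr`. Below is a standalone miniature of that dispatch pattern, with mock scalar types in place of the distributed tensors; it is illustrative only, not the ck_tile implementation.

#include <cstdio>
#include <type_traits>

struct null_tensor {}; // stand-in for ck_tile's null_tensor tag

struct MockBlockNormReduce
{
    // Private Impl() in the real code; kept public here for brevity.
    template <typename VarT>
    static void Impl(float x, float& mean, VarT& var, int& cur_count)
    {
        constexpr bool computeVariance = !std::is_same<VarT, null_tensor>::value;
        ++cur_count;
        mean += x;        // mirrors the non-Welford accumulation branch
        if constexpr(computeVariance)
        {
            var += x * x; // compiled out entirely when VarT is null_tensor
        }
    }

    template <typename VarT>
    void operator()(float x, float& mean, VarT& var, int& cur_count) const
    {
        Impl(x, mean, var, cur_count);
    }

    // The shape of the new mean-only overload: no variance argument at all.
    void operator()(float x, float& mean, int& cur_count) const
    {
        null_tensor dummy{};
        Impl(x, mean, dummy, cur_count);
    }
};

int main()
{
    MockBlockNormReduce reduce;
    float mean = 0.0f, var = 0.0f;
    int count = 0;
    reduce(2.0f, mean, var, count); // mean + variance path
    reduce(4.0f, mean, count);      // mean-only path
    std::printf("sum=%f sum_sq=%f count=%d\n", mean, var, count); // sum=6, sum_sq=4, count=2
    return 0;
}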
@@ -101,12 +139,14 @@ template <typename Problem_, typename Policy_ = void>
 struct BlockNormReduceSync
 {
     using Problem = remove_cvref_t<Problem_>;
+    static constexpr bool kComputeVariance = Problem::kComputeVariance;
     static constexpr bool kFastFDiv = Problem::kFastFDiv;
     static constexpr bool kWelford  = Problem::kWelford;
 
+    private:
     template <typename MeanDistributedTensor_, typename VarDistributedTensor_>
     CK_TILE_DEVICE void
-    operator()(MeanDistributedTensor_& mean_tensor, VarDistributedTensor_& var_tensor, int& count)
+    Impl(MeanDistributedTensor_& mean_tensor, VarDistributedTensor_& var_tensor, int& count)
     {
         using Dstr       = typename MeanDistributedTensor_::StaticTileDistribution;
         using DstrEncode = typename Dstr::DstrEncode;

@@ -120,19 +160,23 @@ struct BlockNormReduceSync
         constexpr index_t idim_p_lane = NDimP - 1;
 
+        constexpr bool computeVariance =
+            !std::is_same<VarDistributedTensor_, null_tensor>::value && kComputeVariance;
+
         // const auto ps_idx = make_array<index_t>(get_warp_id(), get_lane_id());
         // const auto rs_idx =
         //     mean_tensor.get_tile_distribution().calculate_rs_index_from_ps_index(ps_idx);
 
         constexpr index_t thread_buf_size = MeanDistributedTensor_::get_thread_buffer_size();
-        static_assert(thread_buf_size == VarDistributedTensor_::get_thread_buffer_size());
+        static_assert((computeVariance == false) ||
+                      (thread_buf_size == VarDistributedTensor_::get_thread_buffer_size()));
 
         const int original_count = count;
 
         // loop over thread data
         static_for<0, thread_buf_size, 1>{}([&](auto i) {
             auto v_local_mean  = mean_tensor.get_thread_buffer()[i];
-            auto v_local_var   = var_tensor.get_thread_buffer()[i];
+            auto v_local_var   = computeVariance ? var_tensor.get_thread_buffer()[i] : 0;
             auto v_local_count = original_count;
 
             // cross-lane reduce for replication

@@ -161,12 +205,15 @@ struct BlockNormReduceSync
                 // pull data from remote lane
                 const auto v_remote_mean = warp_shuffle(v_local_mean, src_lane);
-                const auto v_remote_var  = warp_shuffle(v_local_var, src_lane);
+                const auto v_remote_var =
+                    computeVariance ? warp_shuffle(v_local_var, src_lane) : 0;
 
                 if(kWelford)
                 {
                     const auto v_remote_count = warp_shuffle(v_local_count, src_lane);
                     // norm_reduce merge
+                    if constexpr(computeVariance)
+                    {
                         welford_merge(v_local_mean,
                                       v_local_var,
                                       v_local_count,

@@ -176,10 +223,22 @@ struct BlockNormReduceSync
                                       constant<kFastFDiv>{});
+                    }
+                    else
+                    {
+                        welford_merge(v_local_mean,
+                                      v_local_count,
+                                      v_remote_mean,
+                                      v_remote_count,
+                                      constant<kFastFDiv>{});
+                    }
                 }
                 else
                 {
                     v_local_mean += v_remote_mean;
+                    if constexpr(computeVariance)
+                    {
                         v_local_var += v_remote_var;
+                    }
                 }
             });
         }
     });

@@ -192,6 +251,20 @@ struct BlockNormReduceSync
             }
         });
     }
 
+    public:
+    template <typename MeanDistributedTensor_, typename VarDistributedTensor_>
+    CK_TILE_DEVICE void
+    operator()(MeanDistributedTensor_& mean_tensor, VarDistributedTensor_& var_tensor, int& count)
+    {
+        Impl(mean_tensor, var_tensor, count);
+    }
+
+    template <typename MeanDistributedTensor_>
+    CK_TILE_DEVICE void operator()(MeanDistributedTensor_& mean_tensor, int& count)
+    {
+        Impl(mean_tensor, null_tensor{}, count);
+    }
 };
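The new mean-only branch above calls a five-argument welford_merge (local mean, local count, remote mean, remote count, fast-division tag) instead of the longer form that also carries the variance. For reference, a textbook formulation of what merging two partial means computes, written as a standalone snippet; the actual ck_tile helper may differ in details such as how kFastFDiv performs the division.

#include <cassert>

// Chan/Welford combination of two partial means (no variance term).
inline void welford_merge_mean_only(float& mean_a, int& count_a, float mean_b, int count_b)
{
    const int merged = count_a + count_b;
    if(merged != 0)
    {
        mean_a += (mean_b - mean_a) * static_cast<float>(count_b) / static_cast<float>(merged);
    }
    count_a = merged;
}

int main()
{
    // two partial results: mean of {1,2,3} and mean of {4,5}
    float m1 = 2.0f; int c1 = 3;
    float m2 = 4.5f; int c2 = 2;
    welford_merge_mean_only(m1, c1, m2, c2);
    assert(c1 == 5);
    assert(m1 > 2.99f && m1 < 3.01f); // merged mean of {1,2,3,4,5} is 3.0
    return 0;
}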
 template <typename Problem_, typename Policy_ = void>

@@ -199,9 +272,13 @@ struct BlockNormReduceCrossWarpSync
 {
     using Problem    = remove_cvref_t<Problem_>;
     using BlockShape = typename Problem::BlockShape;
+    static constexpr bool kComputeVariance = Problem::kComputeVariance;
     static constexpr bool kFastFDiv = Problem::kFastFDiv;
     static constexpr bool kWelford  = Problem::kWelford;
-    using smem_dtype = std::conditional_t<kWelford, fp32x4_t, fp32x2_t>;
+    using smem_dtype =
+        std::conditional_t<kWelford,
+                           typename std::conditional<kComputeVariance, fp32x4_t, fp32x2_t>::type,
+                           typename std::conditional<kComputeVariance, fp32x2_t, fp32x1_t>::type>;
 
     template <typename MeanDistributedTensor_>
     CK_TILE_DEVICE static constexpr index_t GetReduceWarps()

@@ -234,7 +311,12 @@ struct BlockNormReduceCrossWarpSync
     {
         // constexpr auto num_reduce_warps = GetReduceWarps<MeanDistributedTensor_>();
-        // data need to exchange is very small, we just pack mean+var+count -> 4dword
+        // data need to exchange is very small, we just pack mean+var+count -> 4dword if var should
+        // be calculated, or pack mean+count -> 2dword if only mean is taken into account.
+        // Additionally, count is not exchanged if Welford algorithm is not used.
+        constexpr index_t num_dw =
+            kWelford ? (kComputeVariance ? 4 : 2) : (kComputeVariance ? 2 : 1);
+
         constexpr index_t thread_buf_size = MeanDistributedTensor_::get_thread_buffer_size();
 
         // we need to store all data from every wave into smem

@@ -251,11 +333,12 @@ struct BlockNormReduceCrossWarpSync
         //
         // -> also store data from every wave into LDS
         constexpr index_t num_warps = BlockShape::BlockSize / warpSize;
-        return num_warps * 4 * thread_buf_size * sizeof(float);
+        return num_warps * num_dw * thread_buf_size * sizeof(float);
     }
 
+    private:
     template <typename MeanDistributedTensor_, typename VarDistributedTensor_>
-    CK_TILE_DEVICE void operator()(MeanDistributedTensor_& mean_tensor,
+    CK_TILE_DEVICE void Impl(MeanDistributedTensor_& mean_tensor,
                              VarDistributedTensor_& var_tensor,
                              int& count,
                              void* smem)
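With num_dw in place, the LDS footprint returned above scales with the flag combination instead of always assuming four dwords per element. A small self-contained check of the arithmetic, using my own example numbers (not from the commit): BlockSize = 256, a 64-lane wavefront (so num_warps = 4), and thread_buf_size = 1.

#include <cstddef>

constexpr std::size_t smem_bytes(bool kWelford, bool kComputeVariance,
                                 int num_warps, int thread_buf_size)
{
    // same selection as the patched code: dwords per exchanged element
    const int num_dw = kWelford ? (kComputeVariance ? 4 : 2) : (kComputeVariance ? 2 : 1);
    return static_cast<std::size_t>(num_warps) * num_dw * thread_buf_size * sizeof(float);
}

static_assert(smem_bytes(true,  true,  4, 1) == 64, "mean+var+count packed as fp32x4");
static_assert(smem_bytes(true,  false, 4, 1) == 32, "mean+count packed as fp32x2");
static_assert(smem_bytes(false, true,  4, 1) == 32, "mean+var packed as fp32x2");
static_assert(smem_bytes(false, false, 4, 1) == 16, "mean only packed as fp32x1");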
@@ -265,11 +348,17 @@ struct BlockNormReduceCrossWarpSync
         // using DstrEncode       = typename Dstr::DstrEncode;
         // using DstrEncodeDetail = typename DstrEncode::detail;
-        static_assert(std::is_same_v<Dstr, typename VarDistributedTensor_::StaticTileDistribution>,
-                      "wrong!");
+        constexpr bool computeVariance =
+            !std::is_same<VarDistributedTensor_, null_tensor>::value && kComputeVariance;
+        static_assert((computeVariance == false) ||
+                          std::is_same_v<Dstr, typename VarDistributedTensor_::StaticTileDistribution>,
+                      "wrong!");
 
         constexpr index_t thread_buf_size = MeanDistributedTensor_::get_thread_buffer_size();
-        static_assert(thread_buf_size == VarDistributedTensor_::get_thread_buffer_size());
+        static_assert((computeVariance == false) ||
+                      (thread_buf_size == VarDistributedTensor_::get_thread_buffer_size()));
 
         // Note: we always pack everything into fp32x4
         smem_dtype* smem_ptr = reinterpret_cast<smem_dtype*>(smem);

@@ -289,11 +378,18 @@ struct BlockNormReduceCrossWarpSync
         static_for<0, thread_buf_size, 1>{}([&](auto i) {
             smem_dtype local_scratch_;
             local_scratch_[0] = bit_cast<float>(mean_tensor.get_thread_buffer()[i]);
+            if constexpr(kComputeVariance)
+            {
                 local_scratch_[1] = bit_cast<float>(var_tensor.get_thread_buffer()[i]);
-            if(kWelford)
+                if constexpr(kWelford)
                 {
                     local_scratch_[2] = bit_cast<float>(count);
                 }
+            }
+            else if constexpr(kWelford)
+            {
+                local_scratch_[1] = bit_cast<float>(count);
+            }
 
             smem_ptr[smem_offset + i * num_warps] = local_scratch_;
         });
     }
@@ -317,19 +413,22 @@ struct BlockNormReduceCrossWarpSync
             // TODO: use descriptor for this
             auto v_local      = all_scratch[i_0 * num_reduce_warps];
             auto v_local_mean = bit_cast<DataType>(v_local[0]);
-            auto v_local_var  = bit_cast<DataType>(v_local[1]);
-            int v_local_count = kWelford ? bit_cast<int>(v_local[2]) : 0;
+            auto v_local_var  = kComputeVariance ? bit_cast<DataType>(v_local[1]) : 0;
+            int v_local_count =
+                kWelford ? (kComputeVariance ? bit_cast<int>(v_local[2]) : bit_cast<int>(v_local[1])) : 0;
 
             // further reduce mean/var
             static_for<0, num_reduce_warps - 1, 1>{}([&](auto i_1_n1) {
                 constexpr auto i_1 = number<i_1_n1 + 1>{};
 
                 const smem_dtype v_remote = all_scratch[i_0 * num_reduce_warps + i_1];
 
                 const auto v_remote_mean = bit_cast<DataType>(v_remote[0]);
-                const auto v_remote_var  = bit_cast<DataType>(v_remote[1]);
+                const auto v_remote_var  = kComputeVariance ? bit_cast<DataType>(v_remote[1]) : 0;
 
-                if(kWelford)
+                if constexpr(kWelford)
                 {
+                    if constexpr(kComputeVariance)
+                    {
                         const auto v_remote_count = bit_cast<int>(v_remote[2]);
                         welford_merge(v_local_mean,
                                       v_local_var,
                                       v_local_count,

@@ -339,18 +438,49 @@ struct BlockNormReduceCrossWarpSync
                                       constant<kFastFDiv>{});
+                    }
+                    else
+                    {
+                        const auto v_remote_count = bit_cast<int>(v_remote[1]);
+                        welford_merge(v_local_mean,
+                                      v_local_count,
+                                      v_remote_mean,
+                                      v_remote_count,
+                                      constant<kFastFDiv>{});
+                    }
                 }
                 else
                 {
                     v_local_mean += v_remote_mean;
+                    if constexpr(kComputeVariance)
+                    {
                         v_local_var += v_remote_var;
+                    }
                 }
             });
 
             mean_tensor.get_thread_buffer()(i_0) = v_local_mean;
             var_tensor.get_thread_buffer()(i_0)  = v_local_var;
 
-            if(kWelford)
+            if constexpr(kWelford)
             {
                 count = v_local_count;
             }
         });
     }
 
+    public:
+    template <typename MeanDistributedTensor_, typename VarDistributedTensor_>
+    CK_TILE_DEVICE void
+    operator()(MeanDistributedTensor_& mean_tensor,
+               VarDistributedTensor_& var_tensor,
+               int& count,
+               void* smem)
+    {
+        Impl(mean_tensor, var_tensor, count, smem);
+    }
+
+    template <typename MeanDistributedTensor_>
+    CK_TILE_DEVICE void operator()(MeanDistributedTensor_& mean_tensor, int& count, void* smem)
+    {
+        Impl(mean_tensor, null_tensor{}, count, smem);
+    }
 };
 
 // compute the max count for a last dim reduce
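Taken together, the smem_dtype alias, the packing loop (@@ -289) and the unpacking loop (@@ -317) agree on one shared-memory layout per flag combination; summarizing in my own words, derived from the hunks above:

    kWelford = true,  kComputeVariance = true   ->  fp32x4_t: [0] = mean, [1] = var, [2] = count
    kWelford = true,  kComputeVariance = false  ->  fp32x2_t: [0] = mean, [1] = count
    kWelford = false, kComputeVariance = true   ->  fp32x2_t: [0] = mean, [1] = var
    kWelford = false, kComputeVariance = false  ->  fp32x1_t: [0] = mean

The last, mean-only non-Welford case is what motivates the new fp32x1_t alias added in vector_type.hpp.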
include/ck_tile/ops/norm_reduce/block/block_norm_reduce_problem.hpp
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once

@@ -10,6 +10,7 @@ namespace ck_tile {
 template <typename XDataType_,
           typename ComputeDataType_,
           typename BlockShape_,
+          bool kComputeVariance_,
           bool kFastFDiv_,
           bool kWelford_>
 struct BlockNormReduceProblem

@@ -17,6 +18,7 @@ struct BlockNormReduceProblem
     using XDataType       = remove_cvref_t<XDataType_>;
     using ComputeDataType = remove_cvref_t<ComputeDataType_>;
     using BlockShape      = remove_cvref_t<BlockShape_>;
+    static constexpr bool kComputeVariance = kComputeVariance_;
     static constexpr bool kFastFDiv = kFastFDiv_;
     static constexpr bool kWelford  = kWelford_;
 };