Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
c2ea75ed
Commit
c2ea75ed
authored
Jan 21, 2025
by
Jiming Ruan
Browse files
Add no variance support to block_norm_reduce
parent
b16fad32
Changes
4
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
227 additions
and
77 deletions
+227
-77
include/ck_tile/core/numeric/vector_type.hpp
include/ck_tile/core/numeric/vector_type.hpp
+15
-1
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp
...rm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp
+5
-1
include/ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp
include/ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp
+199
-69
include/ck_tile/ops/norm_reduce/block/block_norm_reduce_problem.hpp
..._tile/ops/norm_reduce/block/block_norm_reduce_problem.hpp
+8
-6
No files found.
include/ck_tile/core/numeric/vector_type.hpp
View file @
c2ea75ed
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -78,11 +78,13 @@ using has_same_scalar_type = std::is_same<typename vector_traits<remove_cvref_t<
...
@@ -78,11 +78,13 @@ using has_same_scalar_type = std::is_same<typename vector_traits<remove_cvref_t<
// attention! 2 vector type could be just the same type
// attention! 2 vector type could be just the same type
// fp64
// fp64
using
fp64_t
=
double
;
using
fp64_t
=
double
;
using
fp64x1_t
=
double
__attribute__
((
ext_vector_type
(
1
)));
using
fp64x2_t
=
double
__attribute__
((
ext_vector_type
(
2
)));
using
fp64x2_t
=
double
__attribute__
((
ext_vector_type
(
2
)));
using
fp64x4_t
=
double
__attribute__
((
ext_vector_type
(
4
)));
using
fp64x4_t
=
double
__attribute__
((
ext_vector_type
(
4
)));
// fp32
// fp32
using
fp32_t
=
float
;
using
fp32_t
=
float
;
using
fp32x1_t
=
float
__attribute__
((
ext_vector_type
(
1
)));
using
fp32x2_t
=
float
__attribute__
((
ext_vector_type
(
2
)));
using
fp32x2_t
=
float
__attribute__
((
ext_vector_type
(
2
)));
using
fp32x4_t
=
float
__attribute__
((
ext_vector_type
(
4
)));
using
fp32x4_t
=
float
__attribute__
((
ext_vector_type
(
4
)));
using
fp32x8_t
=
float
__attribute__
((
ext_vector_type
(
8
)));
using
fp32x8_t
=
float
__attribute__
((
ext_vector_type
(
8
)));
...
@@ -92,6 +94,7 @@ using fp32x64_t = float __attribute__((ext_vector_type(64)));
...
@@ -92,6 +94,7 @@ using fp32x64_t = float __attribute__((ext_vector_type(64)));
// fp16
// fp16
// using fp16_t = ...
// using fp16_t = ...
using
fp16x1_t
=
_Float16
__attribute__
((
ext_vector_type
(
1
)));
using
fp16x2_t
=
_Float16
__attribute__
((
ext_vector_type
(
2
)));
using
fp16x2_t
=
_Float16
__attribute__
((
ext_vector_type
(
2
)));
using
fp16x4_t
=
_Float16
__attribute__
((
ext_vector_type
(
4
)));
using
fp16x4_t
=
_Float16
__attribute__
((
ext_vector_type
(
4
)));
using
fp16x8_t
=
_Float16
__attribute__
((
ext_vector_type
(
8
)));
using
fp16x8_t
=
_Float16
__attribute__
((
ext_vector_type
(
8
)));
...
@@ -101,6 +104,7 @@ using fp16x64_t = _Float16 __attribute__((ext_vector_type(64)));
...
@@ -101,6 +104,7 @@ using fp16x64_t = _Float16 __attribute__((ext_vector_type(64)));
// bf16
// bf16
// using bf16_t = ...
// using bf16_t = ...
using
bf16x1_t
=
bf16_raw_t
__attribute__
((
ext_vector_type
(
1
)));
using
bf16x2_t
=
bf16_raw_t
__attribute__
((
ext_vector_type
(
2
)));
using
bf16x2_t
=
bf16_raw_t
__attribute__
((
ext_vector_type
(
2
)));
using
bf16x4_t
=
bf16_raw_t
__attribute__
((
ext_vector_type
(
4
)));
using
bf16x4_t
=
bf16_raw_t
__attribute__
((
ext_vector_type
(
4
)));
using
bf16x8_t
=
bf16_raw_t
__attribute__
((
ext_vector_type
(
8
)));
using
bf16x8_t
=
bf16_raw_t
__attribute__
((
ext_vector_type
(
8
)));
...
@@ -110,6 +114,7 @@ using bf16x64_t = bf16_raw_t __attribute__((ext_vector_type(64)));
...
@@ -110,6 +114,7 @@ using bf16x64_t = bf16_raw_t __attribute__((ext_vector_type(64)));
// i32
// i32
// using int32_t = ...
// using int32_t = ...
using
int32x1_t
=
int32_t
__attribute__
((
ext_vector_type
(
1
)));
using
int32x2_t
=
int32_t
__attribute__
((
ext_vector_type
(
2
)));
using
int32x2_t
=
int32_t
__attribute__
((
ext_vector_type
(
2
)));
using
int32x4_t
=
int32_t
__attribute__
((
ext_vector_type
(
4
)));
using
int32x4_t
=
int32_t
__attribute__
((
ext_vector_type
(
4
)));
using
int32x8_t
=
int32_t
__attribute__
((
ext_vector_type
(
8
)));
using
int32x8_t
=
int32_t
__attribute__
((
ext_vector_type
(
8
)));
...
@@ -119,6 +124,7 @@ using int32x64_t = int32_t __attribute__((ext_vector_type(64)));
...
@@ -119,6 +124,7 @@ using int32x64_t = int32_t __attribute__((ext_vector_type(64)));
// u32
// u32
// using uint32_t = ...
// using uint32_t = ...
using
uint32x1_t
=
uint32_t
__attribute__
((
ext_vector_type
(
1
)));
using
uint32x2_t
=
uint32_t
__attribute__
((
ext_vector_type
(
2
)));
using
uint32x2_t
=
uint32_t
__attribute__
((
ext_vector_type
(
2
)));
using
uint32x4_t
=
uint32_t
__attribute__
((
ext_vector_type
(
4
)));
using
uint32x4_t
=
uint32_t
__attribute__
((
ext_vector_type
(
4
)));
using
uint32x8_t
=
uint32_t
__attribute__
((
ext_vector_type
(
8
)));
using
uint32x8_t
=
uint32_t
__attribute__
((
ext_vector_type
(
8
)));
...
@@ -128,6 +134,7 @@ using uint32x64_t = uint32_t __attribute__((ext_vector_type(64)));
...
@@ -128,6 +134,7 @@ using uint32x64_t = uint32_t __attribute__((ext_vector_type(64)));
// i16
// i16
// using int16_t = ...
// using int16_t = ...
using
int16x1_t
=
int16_t
__attribute__
((
ext_vector_type
(
1
)));
using
int16x2_t
=
int16_t
__attribute__
((
ext_vector_type
(
2
)));
using
int16x2_t
=
int16_t
__attribute__
((
ext_vector_type
(
2
)));
using
int16x4_t
=
int16_t
__attribute__
((
ext_vector_type
(
4
)));
using
int16x4_t
=
int16_t
__attribute__
((
ext_vector_type
(
4
)));
using
int16x8_t
=
int16_t
__attribute__
((
ext_vector_type
(
8
)));
using
int16x8_t
=
int16_t
__attribute__
((
ext_vector_type
(
8
)));
...
@@ -137,6 +144,7 @@ using int16x64_t = int16_t __attribute__((ext_vector_type(64)));
...
@@ -137,6 +144,7 @@ using int16x64_t = int16_t __attribute__((ext_vector_type(64)));
// u16
// u16
// using uint16_t
// using uint16_t
using
uint16x1_t
=
uint16_t
__attribute__
((
ext_vector_type
(
1
)));
using
uint16x2_t
=
uint16_t
__attribute__
((
ext_vector_type
(
2
)));
using
uint16x2_t
=
uint16_t
__attribute__
((
ext_vector_type
(
2
)));
using
uint16x4_t
=
uint16_t
__attribute__
((
ext_vector_type
(
4
)));
using
uint16x4_t
=
uint16_t
__attribute__
((
ext_vector_type
(
4
)));
using
uint16x8_t
=
uint16_t
__attribute__
((
ext_vector_type
(
8
)));
using
uint16x8_t
=
uint16_t
__attribute__
((
ext_vector_type
(
8
)));
...
@@ -146,6 +154,7 @@ using uint16x64_t = uint16_t __attribute__((ext_vector_type(64)));
...
@@ -146,6 +154,7 @@ using uint16x64_t = uint16_t __attribute__((ext_vector_type(64)));
// i8
// i8
// using int8_t
// using int8_t
using
int8x1_t
=
int8_t
__attribute
((
ext_vector_type
(
1
)));
using
int8x2_t
=
int8_t
__attribute
((
ext_vector_type
(
2
)));
using
int8x2_t
=
int8_t
__attribute
((
ext_vector_type
(
2
)));
using
int8x4_t
=
int8_t
__attribute
((
ext_vector_type
(
4
)));
using
int8x4_t
=
int8_t
__attribute
((
ext_vector_type
(
4
)));
using
int8x8_t
=
int8_t
__attribute
((
ext_vector_type
(
8
)));
using
int8x8_t
=
int8_t
__attribute
((
ext_vector_type
(
8
)));
...
@@ -155,6 +164,7 @@ using int8x64_t = int8_t __attribute((ext_vector_type(64)));
...
@@ -155,6 +164,7 @@ using int8x64_t = int8_t __attribute((ext_vector_type(64)));
// ui8
// ui8
// using uint8_t
// using uint8_t
using
uint8x1_t
=
uint8_t
__attribute
((
ext_vector_type
(
1
)));
using
uint8x2_t
=
uint8_t
__attribute
((
ext_vector_type
(
2
)));
using
uint8x2_t
=
uint8_t
__attribute
((
ext_vector_type
(
2
)));
using
uint8x4_t
=
uint8_t
__attribute
((
ext_vector_type
(
4
)));
using
uint8x4_t
=
uint8_t
__attribute
((
ext_vector_type
(
4
)));
using
uint8x8_t
=
uint8_t
__attribute
((
ext_vector_type
(
8
)));
using
uint8x8_t
=
uint8_t
__attribute
((
ext_vector_type
(
8
)));
...
@@ -165,6 +175,7 @@ using uint8x64_t = uint8_t __attribute((ext_vector_type(64)));
...
@@ -165,6 +175,7 @@ using uint8x64_t = uint8_t __attribute((ext_vector_type(64)));
#if CK_TILE_USE_CUSTOM_DATA_TYPE
#if CK_TILE_USE_CUSTOM_DATA_TYPE
// f8
// f8
// using fp8_t
// using fp8_t
using
fp8x1_t
=
fp8_raw_t
__attribute
((
ext_vector_type
(
1
)));
using
fp8x2_t
=
fp8_raw_t
__attribute
((
ext_vector_type
(
2
)));
using
fp8x2_t
=
fp8_raw_t
__attribute
((
ext_vector_type
(
2
)));
using
fp8x4_t
=
fp8_raw_t
__attribute
((
ext_vector_type
(
4
)));
using
fp8x4_t
=
fp8_raw_t
__attribute
((
ext_vector_type
(
4
)));
using
fp8x8_t
=
fp8_raw_t
__attribute
((
ext_vector_type
(
8
)));
using
fp8x8_t
=
fp8_raw_t
__attribute
((
ext_vector_type
(
8
)));
...
@@ -174,6 +185,7 @@ using fp8x64_t = fp8_raw_t __attribute((ext_vector_type(64)));
...
@@ -174,6 +185,7 @@ using fp8x64_t = fp8_raw_t __attribute((ext_vector_type(64)));
// bf8
// bf8
// using bf8_t
// using bf8_t
using
bf8x1_t
=
bf8_raw_t
__attribute
((
ext_vector_type
(
1
)));
using
bf8x2_t
=
bf8_raw_t
__attribute
((
ext_vector_type
(
2
)));
using
bf8x2_t
=
bf8_raw_t
__attribute
((
ext_vector_type
(
2
)));
using
bf8x4_t
=
bf8_raw_t
__attribute
((
ext_vector_type
(
4
)));
using
bf8x4_t
=
bf8_raw_t
__attribute
((
ext_vector_type
(
4
)));
using
bf8x8_t
=
bf8_raw_t
__attribute
((
ext_vector_type
(
8
)));
using
bf8x8_t
=
bf8_raw_t
__attribute
((
ext_vector_type
(
8
)));
...
@@ -183,6 +195,7 @@ using bf8x64_t = bf8_raw_t __attribute((ext_vector_type(64)));
...
@@ -183,6 +195,7 @@ using bf8x64_t = bf8_raw_t __attribute((ext_vector_type(64)));
#else
#else
// f8
// f8
// using fp8_t
// using fp8_t
using
fp8x1_t
=
fp8_t
__attribute
((
ext_vector_type
(
1
)));
using
fp8x2_t
=
fp8_t
__attribute
((
ext_vector_type
(
2
)));
using
fp8x2_t
=
fp8_t
__attribute
((
ext_vector_type
(
2
)));
using
fp8x4_t
=
fp8_t
__attribute
((
ext_vector_type
(
4
)));
using
fp8x4_t
=
fp8_t
__attribute
((
ext_vector_type
(
4
)));
using
fp8x8_t
=
fp8_t
__attribute
((
ext_vector_type
(
8
)));
using
fp8x8_t
=
fp8_t
__attribute
((
ext_vector_type
(
8
)));
...
@@ -192,6 +205,7 @@ using fp8x64_t = fp8_t __attribute((ext_vector_type(64)));
...
@@ -192,6 +205,7 @@ using fp8x64_t = fp8_t __attribute((ext_vector_type(64)));
// bf8
// bf8
// using bf8_t
// using bf8_t
using
bf8x1_t
=
bf8_t
__attribute
((
ext_vector_type
(
1
)));
using
bf8x2_t
=
bf8_t
__attribute
((
ext_vector_type
(
2
)));
using
bf8x2_t
=
bf8_t
__attribute
((
ext_vector_type
(
2
)));
using
bf8x4_t
=
bf8_t
__attribute
((
ext_vector_type
(
4
)));
using
bf8x4_t
=
bf8_t
__attribute
((
ext_vector_type
(
4
)));
using
bf8x8_t
=
bf8_t
__attribute
((
ext_vector_type
(
8
)));
using
bf8x8_t
=
bf8_t
__attribute
((
ext_vector_type
(
8
)));
...
...
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp
View file @
c2ea75ed
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -48,6 +48,7 @@ struct Layernorm2dFwdPipelineDefaultPolicy
...
@@ -48,6 +48,7 @@ struct Layernorm2dFwdPipelineDefaultPolicy
using
P_
=
BlockNormReduceProblem
<
typename
Problem
::
ComputeDataType
,
using
P_
=
BlockNormReduceProblem
<
typename
Problem
::
ComputeDataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
BlockShape
,
typename
Problem
::
BlockShape
,
true
,
Problem
::
Traits
::
kFastFDiv
,
Problem
::
Traits
::
kFastFDiv
,
Problem
::
Traits
::
kWelford
>
;
Problem
::
Traits
::
kWelford
>
;
return
BlockNormReduce
<
P_
>
{};
return
BlockNormReduce
<
P_
>
{};
...
@@ -59,6 +60,7 @@ struct Layernorm2dFwdPipelineDefaultPolicy
...
@@ -59,6 +60,7 @@ struct Layernorm2dFwdPipelineDefaultPolicy
using
P_
=
BlockNormReduceProblem
<
typename
Problem
::
ComputeDataType
,
using
P_
=
BlockNormReduceProblem
<
typename
Problem
::
ComputeDataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
BlockShape
,
typename
Problem
::
BlockShape
,
true
,
Problem
::
Traits
::
kFastFDiv
,
Problem
::
Traits
::
kFastFDiv
,
Problem
::
Traits
::
kWelford
>
;
Problem
::
Traits
::
kWelford
>
;
...
@@ -71,6 +73,7 @@ struct Layernorm2dFwdPipelineDefaultPolicy
...
@@ -71,6 +73,7 @@ struct Layernorm2dFwdPipelineDefaultPolicy
using
P_
=
BlockNormReduceProblem
<
typename
Problem
::
ComputeDataType
,
using
P_
=
BlockNormReduceProblem
<
typename
Problem
::
ComputeDataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
BlockShape
,
typename
Problem
::
BlockShape
,
true
,
Problem
::
Traits
::
kFastFDiv
,
Problem
::
Traits
::
kFastFDiv
,
Problem
::
Traits
::
kWelford
>
;
Problem
::
Traits
::
kWelford
>
;
...
@@ -85,6 +88,7 @@ struct Layernorm2dFwdPipelineDefaultPolicy
...
@@ -85,6 +88,7 @@ struct Layernorm2dFwdPipelineDefaultPolicy
using
P_
=
BlockNormReduceProblem
<
typename
Problem
::
ComputeDataType
,
using
P_
=
BlockNormReduceProblem
<
typename
Problem
::
ComputeDataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
BlockShape
,
typename
Problem
::
BlockShape
,
true
,
Problem
::
Traits
::
kFastFDiv
,
Problem
::
Traits
::
kFastFDiv
,
Problem
::
Traits
::
kWelford
>
;
Problem
::
Traits
::
kWelford
>
;
...
...
include/ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp
View file @
c2ea75ed
This diff is collapsed.
Click to expand it.
include/ck_tile/ops/norm_reduce/block/block_norm_reduce_problem.hpp
View file @
c2ea75ed
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -10,15 +10,17 @@ namespace ck_tile {
...
@@ -10,15 +10,17 @@ namespace ck_tile {
template
<
typename
XDataType_
,
template
<
typename
XDataType_
,
typename
ComputeDataType_
,
typename
ComputeDataType_
,
typename
BlockShape_
,
typename
BlockShape_
,
bool
kComputeVariance_
,
bool
kFastFDiv_
,
bool
kFastFDiv_
,
bool
kWelford_
>
bool
kWelford_
>
struct
BlockNormReduceProblem
struct
BlockNormReduceProblem
{
{
using
XDataType
=
remove_cvref_t
<
XDataType_
>
;
using
XDataType
=
remove_cvref_t
<
XDataType_
>
;
using
ComputeDataType
=
remove_cvref_t
<
ComputeDataType_
>
;
using
ComputeDataType
=
remove_cvref_t
<
ComputeDataType_
>
;
using
BlockShape
=
remove_cvref_t
<
BlockShape_
>
;
using
BlockShape
=
remove_cvref_t
<
BlockShape_
>
;
static
constexpr
bool
kFastFDiv
=
kFastFDiv_
;
static
constexpr
bool
kComputeVariance
=
kComputeVariance_
;
static
constexpr
bool
kWelford
=
kWelford_
;
static
constexpr
bool
kFastFDiv
=
kFastFDiv_
;
static
constexpr
bool
kWelford
=
kWelford_
;
};
};
}
// namespace ck_tile
}
// namespace ck_tile
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment