Commit c2ea75ed authored by Jiming Ruan's avatar Jiming Ruan
Browse files

Add no variance support to block_norm_reduce

parent b16fad32
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...@@ -78,11 +78,13 @@ using has_same_scalar_type = std::is_same<typename vector_traits<remove_cvref_t< ...@@ -78,11 +78,13 @@ using has_same_scalar_type = std::is_same<typename vector_traits<remove_cvref_t<
// attention! 2 vector type could be just the same type // attention! 2 vector type could be just the same type
// fp64 // fp64
using fp64_t = double; using fp64_t = double;
using fp64x1_t = double __attribute__((ext_vector_type(1)));
using fp64x2_t = double __attribute__((ext_vector_type(2))); using fp64x2_t = double __attribute__((ext_vector_type(2)));
using fp64x4_t = double __attribute__((ext_vector_type(4))); using fp64x4_t = double __attribute__((ext_vector_type(4)));
// fp32 // fp32
using fp32_t = float; using fp32_t = float;
using fp32x1_t = float __attribute__((ext_vector_type(1)));
using fp32x2_t = float __attribute__((ext_vector_type(2))); using fp32x2_t = float __attribute__((ext_vector_type(2)));
using fp32x4_t = float __attribute__((ext_vector_type(4))); using fp32x4_t = float __attribute__((ext_vector_type(4)));
using fp32x8_t = float __attribute__((ext_vector_type(8))); using fp32x8_t = float __attribute__((ext_vector_type(8)));
...@@ -92,6 +94,7 @@ using fp32x64_t = float __attribute__((ext_vector_type(64))); ...@@ -92,6 +94,7 @@ using fp32x64_t = float __attribute__((ext_vector_type(64)));
// fp16
// using fp16_t = ...
using fp16x1_t = _Float16 __attribute__((ext_vector_type(1)));
using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
using fp16x4_t = _Float16 __attribute__((ext_vector_type(4)));
using fp16x8_t = _Float16 __attribute__((ext_vector_type(8)));
...@@ -101,6 +104,7 @@ using fp16x64_t = _Float16 __attribute__((ext_vector_type(64))); ...@@ -101,6 +104,7 @@ using fp16x64_t = _Float16 __attribute__((ext_vector_type(64)));
// bf16
// using bf16_t = ...
using bf16x1_t = bf16_raw_t __attribute__((ext_vector_type(1)));
using bf16x2_t = bf16_raw_t __attribute__((ext_vector_type(2)));
using bf16x4_t = bf16_raw_t __attribute__((ext_vector_type(4)));
using bf16x8_t = bf16_raw_t __attribute__((ext_vector_type(8)));
...@@ -110,6 +114,7 @@ using bf16x64_t = bf16_raw_t __attribute__((ext_vector_type(64))); ...@@ -110,6 +114,7 @@ using bf16x64_t = bf16_raw_t __attribute__((ext_vector_type(64)));
// i32 // i32
// using int32_t = ... // using int32_t = ...
using int32x1_t = int32_t __attribute__((ext_vector_type(1)));
using int32x2_t = int32_t __attribute__((ext_vector_type(2))); using int32x2_t = int32_t __attribute__((ext_vector_type(2)));
using int32x4_t = int32_t __attribute__((ext_vector_type(4))); using int32x4_t = int32_t __attribute__((ext_vector_type(4)));
using int32x8_t = int32_t __attribute__((ext_vector_type(8))); using int32x8_t = int32_t __attribute__((ext_vector_type(8)));
...@@ -119,6 +124,7 @@ using int32x64_t = int32_t __attribute__((ext_vector_type(64))); ...@@ -119,6 +124,7 @@ using int32x64_t = int32_t __attribute__((ext_vector_type(64)));
// u32 // u32
// using uint32_t = ... // using uint32_t = ...
using uint32x1_t = uint32_t __attribute__((ext_vector_type(1)));
using uint32x2_t = uint32_t __attribute__((ext_vector_type(2))); using uint32x2_t = uint32_t __attribute__((ext_vector_type(2)));
using uint32x4_t = uint32_t __attribute__((ext_vector_type(4))); using uint32x4_t = uint32_t __attribute__((ext_vector_type(4)));
using uint32x8_t = uint32_t __attribute__((ext_vector_type(8))); using uint32x8_t = uint32_t __attribute__((ext_vector_type(8)));
...@@ -128,6 +134,7 @@ using uint32x64_t = uint32_t __attribute__((ext_vector_type(64))); ...@@ -128,6 +134,7 @@ using uint32x64_t = uint32_t __attribute__((ext_vector_type(64)));
// i16 // i16
// using int16_t = ... // using int16_t = ...
using int16x1_t = int16_t __attribute__((ext_vector_type(1)));
using int16x2_t = int16_t __attribute__((ext_vector_type(2))); using int16x2_t = int16_t __attribute__((ext_vector_type(2)));
using int16x4_t = int16_t __attribute__((ext_vector_type(4))); using int16x4_t = int16_t __attribute__((ext_vector_type(4)));
using int16x8_t = int16_t __attribute__((ext_vector_type(8))); using int16x8_t = int16_t __attribute__((ext_vector_type(8)));
...@@ -137,6 +144,7 @@ using int16x64_t = int16_t __attribute__((ext_vector_type(64))); ...@@ -137,6 +144,7 @@ using int16x64_t = int16_t __attribute__((ext_vector_type(64)));
// u16 // u16
// using uint16_t // using uint16_t
using uint16x1_t = uint16_t __attribute__((ext_vector_type(1)));
using uint16x2_t = uint16_t __attribute__((ext_vector_type(2))); using uint16x2_t = uint16_t __attribute__((ext_vector_type(2)));
using uint16x4_t = uint16_t __attribute__((ext_vector_type(4))); using uint16x4_t = uint16_t __attribute__((ext_vector_type(4)));
using uint16x8_t = uint16_t __attribute__((ext_vector_type(8))); using uint16x8_t = uint16_t __attribute__((ext_vector_type(8)));
...@@ -146,6 +154,7 @@ using uint16x64_t = uint16_t __attribute__((ext_vector_type(64))); ...@@ -146,6 +154,7 @@ using uint16x64_t = uint16_t __attribute__((ext_vector_type(64)));
// i8 // i8
// using int8_t // using int8_t
using int8x1_t = int8_t __attribute((ext_vector_type(1)));
using int8x2_t = int8_t __attribute((ext_vector_type(2))); using int8x2_t = int8_t __attribute((ext_vector_type(2)));
using int8x4_t = int8_t __attribute((ext_vector_type(4))); using int8x4_t = int8_t __attribute((ext_vector_type(4)));
using int8x8_t = int8_t __attribute((ext_vector_type(8))); using int8x8_t = int8_t __attribute((ext_vector_type(8)));
...@@ -155,6 +164,7 @@ using int8x64_t = int8_t __attribute((ext_vector_type(64))); ...@@ -155,6 +164,7 @@ using int8x64_t = int8_t __attribute((ext_vector_type(64)));
// ui8 // ui8
// using uint8_t // using uint8_t
using uint8x1_t = uint8_t __attribute((ext_vector_type(1)));
using uint8x2_t = uint8_t __attribute((ext_vector_type(2))); using uint8x2_t = uint8_t __attribute((ext_vector_type(2)));
using uint8x4_t = uint8_t __attribute((ext_vector_type(4))); using uint8x4_t = uint8_t __attribute((ext_vector_type(4)));
using uint8x8_t = uint8_t __attribute((ext_vector_type(8))); using uint8x8_t = uint8_t __attribute((ext_vector_type(8)));
...@@ -165,6 +175,7 @@ using uint8x64_t = uint8_t __attribute((ext_vector_type(64))); ...@@ -165,6 +175,7 @@ using uint8x64_t = uint8_t __attribute((ext_vector_type(64)));
#if CK_TILE_USE_CUSTOM_DATA_TYPE
// f8
// using fp8_t
// custom-data-type build: vectors are carried in the raw (bit-pattern) type
using fp8x1_t = fp8_raw_t __attribute__((ext_vector_type(1)));
using fp8x2_t = fp8_raw_t __attribute__((ext_vector_type(2)));
using fp8x4_t = fp8_raw_t __attribute__((ext_vector_type(4)));
using fp8x8_t = fp8_raw_t __attribute__((ext_vector_type(8)));
...@@ -174,6 +185,7 @@ using fp8x64_t = fp8_raw_t __attribute((ext_vector_type(64))); ...@@ -174,6 +185,7 @@ using fp8x64_t = fp8_raw_t __attribute((ext_vector_type(64)));
// bf8
// using bf8_t
using bf8x1_t = bf8_raw_t __attribute__((ext_vector_type(1)));
using bf8x2_t = bf8_raw_t __attribute__((ext_vector_type(2)));
using bf8x4_t = bf8_raw_t __attribute__((ext_vector_type(4)));
using bf8x8_t = bf8_raw_t __attribute__((ext_vector_type(8)));
...@@ -183,6 +195,7 @@ using bf8x64_t = bf8_raw_t __attribute((ext_vector_type(64))); ...@@ -183,6 +195,7 @@ using bf8x64_t = bf8_raw_t __attribute((ext_vector_type(64)));
#else
// f8
// using fp8_t
using fp8x1_t = fp8_t __attribute__((ext_vector_type(1)));
using fp8x2_t = fp8_t __attribute__((ext_vector_type(2)));
using fp8x4_t = fp8_t __attribute__((ext_vector_type(4)));
using fp8x8_t = fp8_t __attribute__((ext_vector_type(8)));
...@@ -192,6 +205,7 @@ using fp8x64_t = fp8_t __attribute((ext_vector_type(64))); ...@@ -192,6 +205,7 @@ using fp8x64_t = fp8_t __attribute((ext_vector_type(64)));
// bf8
// using bf8_t
using bf8x1_t = bf8_t __attribute__((ext_vector_type(1)));
using bf8x2_t = bf8_t __attribute__((ext_vector_type(2)));
using bf8x4_t = bf8_t __attribute__((ext_vector_type(4)));
using bf8x8_t = bf8_t __attribute__((ext_vector_type(8)));
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...@@ -48,6 +48,7 @@ struct Layernorm2dFwdPipelineDefaultPolicy ...@@ -48,6 +48,7 @@ struct Layernorm2dFwdPipelineDefaultPolicy
using P_ = BlockNormReduceProblem<typename Problem::ComputeDataType, using P_ = BlockNormReduceProblem<typename Problem::ComputeDataType,
typename Problem::ComputeDataType, typename Problem::ComputeDataType,
typename Problem::BlockShape, typename Problem::BlockShape,
true,
Problem::Traits::kFastFDiv, Problem::Traits::kFastFDiv,
Problem::Traits::kWelford>; Problem::Traits::kWelford>;
return BlockNormReduce<P_>{}; return BlockNormReduce<P_>{};
...@@ -59,6 +60,7 @@ struct Layernorm2dFwdPipelineDefaultPolicy ...@@ -59,6 +60,7 @@ struct Layernorm2dFwdPipelineDefaultPolicy
using P_ = BlockNormReduceProblem<typename Problem::ComputeDataType, using P_ = BlockNormReduceProblem<typename Problem::ComputeDataType,
typename Problem::ComputeDataType, typename Problem::ComputeDataType,
typename Problem::BlockShape, typename Problem::BlockShape,
true,
Problem::Traits::kFastFDiv, Problem::Traits::kFastFDiv,
Problem::Traits::kWelford>; Problem::Traits::kWelford>;
...@@ -71,6 +73,7 @@ struct Layernorm2dFwdPipelineDefaultPolicy ...@@ -71,6 +73,7 @@ struct Layernorm2dFwdPipelineDefaultPolicy
using P_ = BlockNormReduceProblem<typename Problem::ComputeDataType, using P_ = BlockNormReduceProblem<typename Problem::ComputeDataType,
typename Problem::ComputeDataType, typename Problem::ComputeDataType,
typename Problem::BlockShape, typename Problem::BlockShape,
true,
Problem::Traits::kFastFDiv, Problem::Traits::kFastFDiv,
Problem::Traits::kWelford>; Problem::Traits::kWelford>;
...@@ -85,6 +88,7 @@ struct Layernorm2dFwdPipelineDefaultPolicy ...@@ -85,6 +88,7 @@ struct Layernorm2dFwdPipelineDefaultPolicy
using P_ = BlockNormReduceProblem<typename Problem::ComputeDataType, using P_ = BlockNormReduceProblem<typename Problem::ComputeDataType,
typename Problem::ComputeDataType, typename Problem::ComputeDataType,
typename Problem::BlockShape, typename Problem::BlockShape,
true,
Problem::Traits::kFastFDiv, Problem::Traits::kFastFDiv,
Problem::Traits::kWelford>; Problem::Traits::kWelford>;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...@@ -10,15 +10,17 @@ namespace ck_tile { ...@@ -10,15 +10,17 @@ namespace ck_tile {
template <typename XDataType_, template <typename XDataType_,
typename ComputeDataType_, typename ComputeDataType_,
typename BlockShape_, typename BlockShape_,
bool kComputeVariance_,
bool kFastFDiv_, bool kFastFDiv_,
bool kWelford_> bool kWelford_>
struct BlockNormReduceProblem struct BlockNormReduceProblem
{ {
using XDataType = remove_cvref_t<XDataType_>; using XDataType = remove_cvref_t<XDataType_>;
using ComputeDataType = remove_cvref_t<ComputeDataType_>; using ComputeDataType = remove_cvref_t<ComputeDataType_>;
using BlockShape = remove_cvref_t<BlockShape_>; using BlockShape = remove_cvref_t<BlockShape_>;
static constexpr bool kFastFDiv = kFastFDiv_; static constexpr bool kComputeVariance = kComputeVariance_;
static constexpr bool kWelford = kWelford_; static constexpr bool kFastFDiv = kFastFDiv_;
static constexpr bool kWelford = kWelford_;
}; };
} // namespace ck_tile } // namespace ck_tile
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment